diff --git a/.gitignore b/.gitignore index cbfef3fe..b2e76af0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ dist build _build .cache +venv # Installer logs pip-log.txt @@ -18,7 +19,11 @@ nosetests.xml .DS_Store .idea/* +.python-version +.pytest_cache +/setup.py +/requirements.txt /test.py /test_*.py app.py @@ -34,3 +39,5 @@ profile.html benchmark.py results.json *.so +pyproject.lock +poetry.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..7d92bbd3 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/ambv/black + rev: stable + hooks: + - id: black + python_version: python3.7 diff --git a/.travis.yml b/.travis.yml index b761169a..c4572bc3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,32 +1,58 @@ language: python -# Once mysql 5.6 is available on the container infra, this can be removed. -sudo: required -dist: trusty +stages: + - linting + - test -python: - - "2.7" - - "3.5" +cache: + pip: true + directories: + - $HOME/.cache/pypoetry -env: - - MYSQL_PACKAGE=pymysql - - MYSQL_PACKAGE=mysqlclient +services: + - mysql addons: - postgresql: '9.4' - -before_install: - # Manually install mysql 5.6 since the default is v5.5. 
- - sudo apt-get update -qq - - sudo apt-get install -qq mysql-server-5.6 mysql-client-5.6 mysql-client-core-5.6 + postgresql: '9.6' install: - - pip install -r tests-requirements.txt -U - - if [[ $MYSQL_PACKAGE == 'pymysql' ]]; then pip install pymysql; fi - - if [[ $MYSQL_PACKAGE == 'mysqlclient' ]]; then pip install mysqlclient; fi + - curl -fsS -o get-poetry.py https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py + - python get-poetry.py --preview -y + - source $HOME/.poetry/env + - if [[ $MYSQL_PACKAGE == 'pymysql' ]]; then poetry install --extras mysql-python --extras pgsql; fi + - if [[ $MYSQL_PACKAGE == 'mysqlclient' ]]; then poetry install --extras mysql --extras pgsql; fi -script: py.test tests/ +script: pytest tests/ before_script: - psql -c 'create database orator_test;' -U postgres - - mysql -u root -e 'create database orator_test;' + - mysql -e 'create database orator_test;' + +jobs: + include: + - python: "2.7" + env: MYSQL_PACKAGE=pymysql + - python: "2.7" + env: MYSQL_PACKAGE=mysqlclient + - python: "3.5" + env: MYSQL_PACKAGE=pymysql + - python: "3.5" + env: MYSQL_PACKAGE=mysqlclient + - python: "3.6" + env: MYSQL_PACKAGE=pymysql + - python: "3.6" + env: MYSQL_PACKAGE=mysqlclient + - python: "3.7" + env: MYSQL_PACKAGE=pymysql + dist: xenial + - python: "3.7" + env: MYSQL_PACKAGE=mysqlclient + dist: xenial + + - stage: linting + python: "3.6" + install: + - pip install pre-commit + - pre-commit install-hooks + script: + - pre-commit run --all-files diff --git a/CHANGELOG.md b/CHANGELOG.md index af811624..220ee067 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,80 @@ # Change Log +## [0.9.9] - 2019-07-15 + +### Fixed + +- Fixed missing relationships when eager loading multiple nested relationships. +- Fixed a possible `AttributeError` when starting a transaction. +- Fixed an infinite recursion when using `where_exists()` on a soft-deletable model. +- Fixed some cases where a reconnection would not occur for PostgreSQL. 
+ + +## [0.9.8] - 2018-10-10 + +### Fixed + +- Fixed the `morphed_by_many()` decorator. +- Fixed decoding errors for MySQL. +- Fixed connection errors check. +- Fixed the `touches()` method. +- Fixed `has_many` not showing `DISTINCT`. +- Fixed `save_many()` for Python 3. +- Fixed an error when listing columns for recent MySQL versions. + + +## [0.9.7] - 2017-05-17 + +### Fixed + +- Fixed `orator` command no longer working. + + +## [0.9.6] - 2017-05-16 + +### Added + +- Added support for `DATE` types in models. +- Added support for fractional seconds for the `TIMESTAMP` type in MySQL 5.6.4+. +- Added support for fractional seconds for the `TIMESTAMP` and `TIME` types in PostgreSQL. + +### Changed + +- Improved implementation of the `chunk` method. + +### Fixed + +- Fixed timezone offset errors when inserting datetime aware objects into PostgreSQL. +- Fixed a bug occurring when using `__touches__` with an optional relationship. +- Fixed collections serialization when using the query builder. + + +## [0.9.5] - 2017-02-11 + +### Changed + +- `make:migration` now shows the name of the created migration file. (Thanks to [denislins](https://github.com/denislins)) + +### Fixed + +- Fixed transactions not working for PostgreSQL and SQLite. + + +## [0.9.4] - 2017-01-12 + +### Fixed + +- Fixes `BelongsTo.associate()` for non saved models. +- Fixes reconnection for PostgreSQL. +- Fixes dependencies (changed `fake-factory` to `Faker`) (thanks to [acristoffers](https://github.com/acristoffers)) + + +## [0.9.3] - 2016-11-10 + +### Fixed + +- Fixes `compile_table_exists()` method in PostgreSQL schema grammar that could break migrations. 
+ ## [0.9.2] - 2016-10-17 @@ -361,7 +436,14 @@ Initial release - +[Unreleased]: https://github.com/sdispater/orator/compare/0.9.9...0.9 +[0.9.9]: https://github.com/sdispater/orator/releases/0.9.9 +[0.9.8]: https://github.com/sdispater/orator/releases/0.9.8 +[0.9.7]: https://github.com/sdispater/orator/releases/0.9.7 +[0.9.6]: https://github.com/sdispater/orator/releases/0.9.6 +[0.9.5]: https://github.com/sdispater/orator/releases/0.9.5 +[0.9.4]: https://github.com/sdispater/orator/releases/0.9.4 +[0.9.3]: https://github.com/sdispater/orator/releases/0.9.3 [0.9.2]: https://github.com/sdispater/orator/releases/0.9.2 [0.9.1]: https://github.com/sdispater/orator/releases/0.9.1 [0.9.0]: https://github.com/sdispater/orator/releases/0.9.0 diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index c602245e..00000000 --- a/MANIFEST.in +++ /dev/null @@ -1,6 +0,0 @@ -include README.rst LICENSE requirements.txt test-requirements.txt -recursive-exclude tests * -recursive-exclude benchmark * -recursive-exclude seeders * -recursive-exclude seeds * -recursive-exclude migrations * diff --git a/Makefile b/Makefile index 9fd505fd..26c890c4 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ setup-mysql: drop-mysql TO 'orator'@'localhost';" drop-mysql: - @type -p psql > /dev/null || { echo 'Install and setup PostgreSQL'; exit 1; } + @type -p mysql > /dev/null || { echo 'Install and setup MySQL'; exit 1; } @-mysql -u root -e 'DROP DATABASE orator_test;' > /dev/null 2>&1 @-mysql -u root -e "DROP USER 'orator'@'localhost';" > /dev/null 2>&1 @@ -62,4 +62,6 @@ test: # run tests against all supported python versions tox: + @poet make:setup @tox + @rm -f setup.py diff --git a/README.rst b/README.rst index a055a5dd..afffd0e9 100644 --- a/README.rst +++ b/README.rst @@ -361,7 +361,7 @@ Note that entire collections of models can also be converted to dictionaries: .. 
code-block:: python - return User.all().serailize() + return User.all().serialize() Converting a model to JSON diff --git a/orator/__init__.py b/orator/__init__.py index dba4d8b3..0316577d 100644 --- a/orator/__init__.py +++ b/orator/__init__.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +__version__ = "0.9.9" + from .orm import Model, SoftDeletes, Collection, accessor, mutator, scope from .database_manager import DatabaseManager from .query.expression import QueryExpression diff --git a/orator/commands/__init__.py b/orator/commands/__init__.py index 633f8661..40a96afc 100644 --- a/orator/commands/__init__.py +++ b/orator/commands/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/orator/commands/application.py b/orator/commands/application.py index c9313ffd..9f1435f1 100644 --- a/orator/commands/application.py +++ b/orator/commands/application.py @@ -1,15 +1,19 @@ # -*- coding: utf-8 -*- from cleo import Application -from ..version import VERSION +from .. import __version__ -application = Application('Orator', VERSION, complete=True) +application = Application("Orator", __version__, complete=True) # Migrations from .migrations import ( - InstallCommand, MigrateCommand, - MigrateMakeCommand, RollbackCommand, - StatusCommand, ResetCommand, RefreshCommand + InstallCommand, + MigrateCommand, + MigrateMakeCommand, + RollbackCommand, + StatusCommand, + ResetCommand, + RefreshCommand, ) application.add(InstallCommand()) diff --git a/orator/commands/command.py b/orator/commands/command.py index 19351d56..2a114400 100644 --- a/orator/commands/command.py +++ b/orator/commands/command.py @@ -23,18 +23,18 @@ def configure(self): if self.needs_config and not self.resolver: # Checking if a default config file is present if not self._check_config(): - self.add_option('config', 'c', - InputOption.VALUE_REQUIRED, - 'The config file path') + self.add_option( + "config", "c", InputOption.VALUE_REQUIRED, "The config file path" + ) def execute(self, i, o): """ Executes the 
command. """ - self.set_style('question', fg='blue') + self.set_style("question", fg="blue") if self.needs_config and not self.resolver: - self._handle_config(self.option('config')) + self._handle_config(self.option("config")) return self.handle() @@ -50,28 +50,24 @@ def call_silent(self, name, options=None): return super(Command, self).call_silent(name, options) - def confirm(self, question, default=False, true_answer_regex='(?i)^y'): - """ - Confirm a question with the user. + def confirm_to_proceed(self, message=None): + if message is None: + message = "Do you really wish to run this command?: " - :param question: The question to ask - :type question: str + if self.option("force"): + return True - :param default: The default value - :type default: bool + confirmed = self.confirm(message) - :param true_answer_regex: A regex to match the "yes" answer - :type true_answer_regex: str + if not confirmed: + self.comment("Command Cancelled!") - :rtype: bool - """ - if not self.input.is_interactive(): - return True + return False - return super(Command, self).confirm(question, default=False, true_answer_regex='(?i)^y') + return True def _get_migration_path(self): - return os.path.join(os.getcwd(), 'migrations') + return os.path.join(os.getcwd(), "migrations") def _check_config(self): """ @@ -81,7 +77,7 @@ def _check_config(self): """ current_path = os.path.relpath(os.getcwd()) - accepted_files = ['orator.yml', 'orator.py'] + accepted_files = ["orator.yml", "orator.py"] for accepted_file in accepted_files: config_file = os.path.join(current_path, accepted_file) if os.path.exists(config_file): @@ -101,7 +97,9 @@ def _handle_config(self, config_file): """ config = self._get_config(config_file) - self.resolver = DatabaseManager(config.get('databases', config.get('DATABASES', {}))) + self.resolver = DatabaseManager( + config.get("databases", config.get("DATABASES", {})) + ) return True @@ -111,22 +109,22 @@ def _get_config(self, path=None): :rtype: dict """ - if not path and 
not self.option('config'): - raise Exception('The --config|-c option is missing.') + if not path and not self.option("config"): + raise Exception("The --config|-c option is missing.") if not path: - path = self.option('config') + path = self.option("config") filename, ext = os.path.splitext(path) - if ext in ['.yml', '.yaml']: + if ext in [".yml", ".yaml"]: with open(path) as fd: config = yaml.load(fd) - elif ext in ['.py']: + elif ext in [".py"]: config = {} with open(path) as fh: exec(fh.read(), {}, config) else: - raise RuntimeError('Config file [%s] is not supported.' % path) + raise RuntimeError("Config file [%s] is not supported." % path) return config diff --git a/orator/commands/migrations/base_command.py b/orator/commands/migrations/base_command.py index 85188c84..ba7ec414 100644 --- a/orator/commands/migrations/base_command.py +++ b/orator/commands/migrations/base_command.py @@ -6,6 +6,5 @@ class BaseCommand(Command): - def _get_migration_path(self): - return os.path.join(os.getcwd(), 'migrations') + return os.path.join(os.getcwd(), "migrations") diff --git a/orator/commands/migrations/install_command.py b/orator/commands/migrations/install_command.py index 161a5485..b8738698 100644 --- a/orator/commands/migrations/install_command.py +++ b/orator/commands/migrations/install_command.py @@ -16,10 +16,10 @@ def handle(self): """ Executes the command """ - database = self.option('database') - repository = DatabaseMigrationRepository(self.resolver, 'migrations') + database = self.option("database") + repository = DatabaseMigrationRepository(self.resolver, "migrations") repository.set_source(database) repository.create_repository() - self.info('Migration table created successfully') + self.info("Migration table created successfully") diff --git a/orator/commands/migrations/make_command.py b/orator/commands/migrations/make_command.py index d0a46e38..e5c929c0 100644 --- a/orator/commands/migrations/make_command.py +++ b/orator/commands/migrations/make_command.py 
@@ -24,20 +24,20 @@ def handle(self): """ creator = MigrationCreator() - name = self.argument('name') - table = self.option('table') - create = bool(self.option('create')) + name = self.argument("name") + table = self.option("table") + create = bool(self.option("create")) if not table and create is not False: table = create - path = self.option('path') + path = self.option("path") if path is None: path = self._get_migration_path() - self._write_migration(creator, name, table, create, path) + migration_name = self._write_migration(creator, name, table, create, path) - self.info('Migration created successfully') + self.line("Created migration: {}".format(migration_name)) def _write_migration(self, creator, name, table, create, path): """ diff --git a/orator/commands/migrations/migrate_command.py b/orator/commands/migrations/migrate_command.py index 9faef082..24dec942 100644 --- a/orator/commands/migrations/migrate_command.py +++ b/orator/commands/migrations/migrate_command.py @@ -15,26 +15,25 @@ class MigrateCommand(BaseCommand): {--seed-path= : The path of seeds files to be executed. Defaults to ./seeders.} {--P|pretend : Dump the SQL queries that would be run.} + {--f|force : Force the operation to run.} """ def handle(self): - confirm = self.confirm( - 'Are you sure you want to proceed with the migration? ', - False - ) - if not confirm: + if not self.confirm_to_proceed( + "Are you sure you want to proceed with the migration? 
" + ): return - database = self.option('database') - repository = DatabaseMigrationRepository(self.resolver, 'migrations') + database = self.option("database") + repository = DatabaseMigrationRepository(self.resolver, "migrations") migrator = Migrator(repository, self.resolver) self._prepare_database(migrator, database) - pretend = self.option('pretend') + pretend = self.option("pretend") - path = self.option('path') + path = self.option("path") if path is None: path = self._get_migration_path() @@ -46,21 +45,19 @@ def handle(self): # If the "seed" option has been given, we will rerun the database seed task # to repopulate the database. - if self.option('seed'): - options = [ - ('-n', True) - ] + if self.option("seed"): + options = [("--force", self.option("force"))] if database: - options.append(('--database', database)) + options.append(("--database", database)) - if self.get_definition().has_option('config'): - options.append(('--config', self.option('config'))) + if self.get_definition().has_option("config"): + options.append(("--config", self.option("config"))) - if self.option('seed-path'): - options.append(('--path', self.option('seed-path'))) + if self.option("seed-path"): + options.append(("--path", self.option("seed-path"))) - self.call('db:seed', options) + self.call("db:seed", options) def _prepare_database(self, migrator, database): migrator.set_connection(database) @@ -69,9 +66,9 @@ def _prepare_database(self, migrator, database): options = [] if database: - options.append(('--database', database)) + options.append(("--database", database)) - if self.get_definition().has_option('config'): - options.append(('--config', self.option('config'))) + if self.get_definition().has_option("config"): + options.append(("--config", self.option("config"))) - self.call('migrate:install', options) + self.call("migrate:install", options) diff --git a/orator/commands/migrations/refresh_command.py b/orator/commands/migrations/refresh_command.py index d9c44028..7dfded18 
100644 --- a/orator/commands/migrations/refresh_command.py +++ b/orator/commands/migrations/refresh_command.py @@ -12,59 +12,53 @@ class RefreshCommand(BaseCommand): {--p|path= : The path of migrations files to be executed.} {--s|seed : Indicates if the seed task should be re-run.} {--seed-path= : The path of seeds files to be executed. - Defaults to ./seeders.} + Defaults to ./seeds.} {--seeder=database_seeder : The name of the root seeder.} + {--f|force : Force the operation to run.} """ def handle(self): """ Executes the command. """ - confirm = self.confirm( - 'Are you sure you want to refresh the database? ', - False - ) - if not confirm: + if not self.confirm_to_proceed( + "Are you sure you want to refresh the database?: " + ): return - database = self.option('database') + database = self.option("database") - options = [ - ('-n', True) - ] + options = [("--force", True)] - if self.option('path'): - options.append(('--path', self.option('path'))) + if self.option("path"): + options.append(("--path", self.option("path"))) if database: - options.append(('--database', database)) + options.append(("--database", database)) - if self.get_definition().has_option('config'): - options.append(('--config', self.option('config'))) + if self.get_definition().has_option("config"): + options.append(("--config", self.option("config"))) - self.call('migrate:reset', options) + self.call("migrate:reset", options) - self.call('migrate', options) + self.call("migrate", options) if self._needs_seeding(): self._run_seeder(database) def _needs_seeding(self): - return self.option('seed') or self.option('seeder') + return self.option("seed") def _run_seeder(self, database): - options = [ - ('--seeder', self.option('seeder')), - ('-n', True) - ] + options = [("--seeder", self.option("seeder")), ("--force", True)] if database: - options.append(('--database', database)) + options.append(("--database", database)) - if self.get_definition().has_option('config'): - 
options.append(('--config', self.option('config'))) + if self.get_definition().has_option("config"): + options.append(("--config", self.option("config"))) - if self.option('seed-path'): - options.append(('--path', self.option('seed-path'))) + if self.option("seed-path"): + options.append(("--path", self.option("seed-path"))) - self.call('db:seed', options) + self.call("db:seed", options) diff --git a/orator/commands/migrations/reset_command.py b/orator/commands/migrations/reset_command.py index 92fc5bc5..55978eec 100644 --- a/orator/commands/migrations/reset_command.py +++ b/orator/commands/migrations/reset_command.py @@ -12,29 +12,28 @@ class ResetCommand(BaseCommand): {--d|database= : The database connection to use.} {--p|path= : The path of migrations files to be executed.} {--P|pretend : Dump the SQL queries that would be run.} + {--f|force : Force the operation to run.} """ def handle(self): """ Executes the command. """ - confirm = self.confirm( - 'Are you sure you want to reset all of the migrations? 
', - False - ) - if not confirm: + if not self.confirm_to_proceed( + "Are you sure you want to reset all of the migrations?: " + ): return - database = self.option('database') - repository = DatabaseMigrationRepository(self.resolver, 'migrations') + database = self.option("database") + repository = DatabaseMigrationRepository(self.resolver, "migrations") migrator = Migrator(repository, self.resolver) self._prepare_database(migrator, database) - pretend = bool(self.option('pretend')) + pretend = bool(self.option("pretend")) - path = self.option('path') + path = self.option("path") if path is None: path = self._get_migration_path() diff --git a/orator/commands/migrations/rollback_command.py b/orator/commands/migrations/rollback_command.py index 47607872..2b789d65 100644 --- a/orator/commands/migrations/rollback_command.py +++ b/orator/commands/migrations/rollback_command.py @@ -12,29 +12,28 @@ class RollbackCommand(BaseCommand): {--d|database= : The database connection to use.} {--p|path= : The path of migrations files to be executed.} {--P|pretend : Dump the SQL queries that would be run.} + {--f|force : Force the operation to run.} """ def handle(self): """ Executes the command. """ - confirm = self.confirm( - 'Are you sure you want to rollback the last migration? 
', - True - ) - if not confirm: + if not self.confirm_to_proceed( + "Are you sure you want to rollback the last migration?: " + ): return - database = self.option('database') - repository = DatabaseMigrationRepository(self.resolver, 'migrations') + database = self.option("database") + repository = DatabaseMigrationRepository(self.resolver, "migrations") migrator = Migrator(repository, self.resolver) self._prepare_database(migrator, database) - pretend = self.option('pretend') + pretend = self.option("pretend") - path = self.option('path') + path = self.option("path") if path is None: path = self._get_migration_path() diff --git a/orator/commands/migrations/status_command.py b/orator/commands/migrations/status_command.py index bd74f881..7560c480 100644 --- a/orator/commands/migrations/status_command.py +++ b/orator/commands/migrations/status_command.py @@ -17,20 +17,20 @@ def handle(self): """ Executes the command. """ - database = self.option('database') + database = self.option("database") self.resolver.set_default_connection(database) - repository = DatabaseMigrationRepository(self.resolver, 'migrations') + repository = DatabaseMigrationRepository(self.resolver, "migrations") migrator = Migrator(repository, self.resolver) if not migrator.repository_exists(): - return self.error('No migrations found') + return self.error("No migrations found") self._prepare_database(migrator, database) - path = self.option('path') + path = self.option("path") if path is None: path = self._get_migration_path() @@ -40,18 +40,15 @@ def handle(self): migrations = [] for migration in migrator._get_migration_files(path): if migration in ran: - migrations.append(['%s' % migration, 'Yes']) + migrations.append(["%s" % migration, "Yes"]) else: - migrations.append(['%s' % migration, 'No']) + migrations.append(["%s" % migration, "No"]) if migrations: - table = self.table( - ['Migration', 'Ran?'], - migrations - ) + table = self.table(["Migration", "Ran?"], migrations) table.render() else: - 
return self.error('No migrations found') + return self.error("No migrations found") for note in migrator.get_notes(): self.line(note) diff --git a/orator/commands/models/make_command.py b/orator/commands/models/make_command.py index 1280dc31..22be9f5b 100644 --- a/orator/commands/models/make_command.py +++ b/orator/commands/models/make_command.py @@ -18,39 +18,39 @@ class ModelMakeCommand(Command): """ def handle(self): - name = self.argument('name') + name = self.argument("name") singular = inflection.singularize(inflection.tableize(name)) directory = self._get_path() - filepath = self._get_path(singular + '.py') + filepath = self._get_path(singular + ".py") if os.path.exists(filepath): - raise RuntimeError('The model file already exists.') + raise RuntimeError("The model file already exists.") mkdir_p(directory) - parent = os.path.join(directory, '__init__.py') + parent = os.path.join(directory, "__init__.py") if not os.path.exists(parent): - with open(parent, 'w'): + with open(parent, "w"): pass stub = self._get_stub() stub = self._populate_stub(name, stub) - with open(filepath, 'w') as f: + with open(filepath, "w") as f: f.write(stub) - self.info('Model %s successfully created.' % name) + self.info("Model %s successfully created." 
% name) - if self.option('migration'): + if self.option("migration"): table = inflection.tableize(name) self.call( - 'make:migration', + "make:migration", [ - ('name', 'create_%s_table' % table), - ('--table', table), - ('--create', True) - ] + ("name", "create_%s_table" % table), + ("--table", table), + ("--create", True), + ], ) def _get_stub(self): @@ -73,15 +73,15 @@ def _populate_stub(self, name, stub): :rtype: str """ - stub = stub.replace('DummyClass', name) + stub = stub.replace("DummyClass", name) return stub def _get_path(self, name=None): - if self.option('path'): - directory = self.option('path') + if self.option("path"): + directory = self.option("path") else: - directory = os.path.join(os.getcwd(), 'models') + directory = os.path.join(os.getcwd(), "models") if name: return os.path.join(directory, name) diff --git a/orator/commands/seeds/base_command.py b/orator/commands/seeds/base_command.py index f14646b4..4169466b 100644 --- a/orator/commands/seeds/base_command.py +++ b/orator/commands/seeds/base_command.py @@ -5,6 +5,5 @@ class BaseCommand(Command): - def _get_seeders_path(self): - return os.path.join(os.getcwd(), 'seeds') + return os.path.join(os.getcwd(), "seeds") diff --git a/orator/commands/seeds/make_command.py b/orator/commands/seeds/make_command.py index a3cd6d82..a6d5cc1f 100644 --- a/orator/commands/seeds/make_command.py +++ b/orator/commands/seeds/make_command.py @@ -24,9 +24,9 @@ def handle(self): Executes the command. 
""" # Making root seeder - self._make('database_seeder', True) + self._make("database_seeder", True) - self._make(self.argument('name')) + self._make(self.argument("name")) def _make(self, name, root=False): name = self._parse_name(name) @@ -34,24 +34,24 @@ def _make(self, name, root=False): path = self._get_path(name) if os.path.exists(path): if not root: - self.error('%s already exists' % name) + self.error("%s already exists" % name) return False self._make_directory(os.path.dirname(path)) - with open(path, 'w') as fh: + with open(path, "w") as fh: fh.write(self._build_class(name)) if root: - with open(os.path.join(os.path.dirname(path), '__init__.py'), 'w'): + with open(os.path.join(os.path.dirname(path), "__init__.py"), "w"): pass - self.info('%s created successfully.' % name) + self.info("%s created successfully." % name) def _parse_name(self, name): - if name.endswith('.py'): - name = name.replace('.py', '', -1) + if name.endswith(".py"): + name = name.replace(".py", "", -1) return name @@ -64,11 +64,11 @@ def _get_path(self, name): :rtype: str """ - path = self.option('path') + path = self.option("path") if path is None: path = self._get_seeders_path() - return os.path.join(path, '%s.py' % name) + return os.path.join(path, "%s.py" % name) def _make_directory(self, path): try: @@ -83,7 +83,7 @@ def _build_class(self, name): stub = self._get_stub() klass = self._get_class_name(name) - stub = stub.replace('DummyClass', klass) + stub = stub.replace("DummyClass", klass) return stub diff --git a/orator/commands/seeds/seed_command.py b/orator/commands/seeds/seed_command.py index 607c4630..9e0ce6c1 100644 --- a/orator/commands/seeds/seed_command.py +++ b/orator/commands/seeds/seed_command.py @@ -18,34 +18,33 @@ class SeedCommand(BaseCommand): {--p|path= : The path to seeders files. Defaults to ./seeds.} {--seeder=database_seeder : The name of the root seeder.} + {--f|force : Force the operation to run.} """ def handle(self): """ Executes the command. 
""" - confirm = self.confirm( - 'Are you sure you want to seed the database? ', - False - ) - if not confirm: + if not self.confirm_to_proceed( + "Are you sure you want to seed the database?: " + ): return - self.resolver.set_default_connection(self.option('database')) + self.resolver.set_default_connection(self.option("database")) self._get_seeder().run() - self.info('Database seeded!') + self.info("Database seeded!") def _get_seeder(self): - name = self._parse_name(self.option('seeder')) + name = self._parse_name(self.option("seeder")) seeder_file = self._get_path(name) # Loading parent module - load_module('seeds', self._get_path('__init__')) + load_module("seeds", self._get_path("__init__")) # Loading module - mod = load_module('seeds.%s' % name, seeder_file) + mod = load_module("seeds.%s" % name, seeder_file) klass = getattr(mod, inflection.camelize(name)) @@ -56,8 +55,8 @@ def _get_seeder(self): return instance def _parse_name(self, name): - if name.endswith('.py'): - name = name.replace('.py', '', -1) + if name.endswith(".py"): + name = name.replace(".py", "", -1) return name @@ -70,8 +69,8 @@ def _get_path(self, name): :rtype: str """ - path = self.option('path') + path = self.option("path") if path is None: path = self._get_seeders_path() - return os.path.join(path, '%s.py' % name) + return os.path.join(path, "%s.py" % name) diff --git a/orator/connections/connection.py b/orator/connections/connection.py index 14cace24..469b5989 100644 --- a/orator/connections/connection.py +++ b/orator/connections/connection.py @@ -14,14 +14,15 @@ from ..exceptions.query import QueryException -query_logger = logging.getLogger('orator.connection.queries') -connection_logger = logging.getLogger('orator.connection') +query_logger = logging.getLogger("orator.connection.queries") +connection_logger = logging.getLogger("orator.connection") def run(wrapped): """ Special decorator encapsulating query method. 
""" + @wraps(wrapped) def _run(self, query, bindings=None, *args, **kwargs): self._reconnect_if_missing_connection() @@ -46,8 +47,15 @@ class Connection(ConnectionInterface): name = None - def __init__(self, connection, database='', table_prefix='', config=None, - builder_class=QueryBuilder, builder_default_kwargs=None): + def __init__( + self, + connection, + database="", + table_prefix="", + config=None, + builder_class=QueryBuilder, + builder_default_kwargs=None, + ): """ :param connection: A dbapi connection instance :type connection: Connector @@ -69,7 +77,7 @@ def __init__(self, connection, database='', table_prefix='', config=None, self._database = database if table_prefix is None: - table_prefix = '' + table_prefix = "" self._table_prefix = table_prefix @@ -91,13 +99,13 @@ def __init__(self, connection, database='', table_prefix='', config=None, self._builder_default_kwargs = builder_default_kwargs - self._logging_queries = config.get('log_queries', False) + self._logging_queries = config.get("log_queries", False) self._logged_queries = [] # Setting the marker based on config self._marker = None - if self._config.get('use_qmark'): - self._marker = '?' + if self._config.get("use_qmark"): + self._marker = "?" 
self._query_grammar = self.get_default_query_grammar() @@ -163,7 +171,9 @@ def query(self): :rtype: QueryBuilder """ query = self._builder_class( - self, self._query_grammar, self._post_processor, + self, + self._query_grammar, + self._post_processor, **self._builder_default_kwargs ) @@ -200,6 +210,34 @@ def select(self, query, bindings=None, use_read_connection=True): return cursor.fetchall() + def select_many( + self, size, query, bindings=None, use_read_connection=True, abort=False + ): + if self.pretending(): + yield [] + else: + bindings = self.prepare_bindings(bindings) + cursor = self._get_cursor_for_select(use_read_connection) + + try: + cursor.execute(query, bindings) + except Exception as e: + if self._caused_by_lost_connection(e) and not abort: + self.reconnect() + + for results in self.select_many( + size, query, bindings, use_read_connection, True + ): + yield results + else: + raise + else: + results = cursor.fetchmany(size) + while results: + yield results + + results = cursor.fetchmany(size) + def _get_cursor_for_select(self, use_read_connection=True): if use_read_connection: self._cursor = self.get_read_connection().cursor() @@ -308,7 +346,9 @@ def pretend(self): self._pretending = False - def _try_again_if_caused_by_lost_connection(self, e, query, bindings, callback, *args, **kwargs): + def _try_again_if_caused_by_lost_connection( + self, e, query, bindings, callback, *args, **kwargs + ): if self._caused_by_lost_connection(e): self.reconnect() @@ -317,18 +357,28 @@ def _try_again_if_caused_by_lost_connection(self, e, query, bindings, callback, raise QueryException(query, bindings, e) def _caused_by_lost_connection(self, e): - message = str(e) - - for s in ['server has gone away', - 'no connection to the server', - 'Lost Connection']: + message = str(e).lower() + + for s in [ + "server has gone away", + "no connection to the server", + "lost connection", + "is dead or not enabled", + "error while sending", + "decryption failed or bad record mac", + 
"server closed the connection unexpectedly", + "ssl connection has been closed unexpectedly", + "error writing data to the connection", + "connection timed out", + "resource deadlock avoided", + ]: if s in message: return True return False def disconnect(self): - connection_logger.debug('%s is disconnecting' % self.__class__.__name__) + connection_logger.debug("%s is disconnecting" % self.__class__.__name__) if self._connection: self._connection.close() @@ -337,14 +387,14 @@ def disconnect(self): self.set_connection(None).set_read_connection(None) - connection_logger.debug('%s disconnected' % self.__class__.__name__) + connection_logger.debug("%s disconnected" % self.__class__.__name__) def reconnect(self): - connection_logger.debug('%s is reconnecting' % self.__class__.__name__) + connection_logger.debug("%s is reconnecting" % self.__class__.__name__) if self._reconnector is not None and callable(self._reconnector): return self._reconnector(self) - raise Exception('Lost connection and no reconnector available') + raise Exception("Lost connection and no reconnector available") def _reconnect_if_missing_connection(self): if self.get_connection() is None or self.get_read_connection() is None: @@ -360,17 +410,14 @@ def log_query(self, query, bindings, time_=None): query = self._get_cursor_query(query, bindings) if query: - log = 'Executed %s' % (query,) + log = "Executed %s" % (query,) if time_: - log += ' in %sms' % time_ + log += " in %sms" % time_ - query_logger.debug(log, - extra={ - 'query': query, - 'bindings': bindings, - 'elapsed_time': time_ - }) + query_logger.debug( + log, extra={"query": query, "bindings": bindings, "elapsed_time": time_} + ) def _get_elapsed_time(self, start): return round((time.time() - start) * 1000, 2) @@ -398,8 +445,9 @@ def get_read_connection(self): def set_connection(self, connection): if self._transactions >= 1: - raise RuntimeError("Can't swap dbapi connection" - "while within transaction.") + raise RuntimeError( + "Can't swap 
dbapi connection" "while within transaction." + ) self._connection = connection @@ -416,7 +464,7 @@ def set_reconnector(self, reconnector): return self def get_name(self): - return self._config.get('name') + return self._config.get("name") def get_config(self, option): return self._config.get(option) @@ -502,6 +550,7 @@ def set_builder_class(self, klass, default_kwargs=None): return self def __enter__(self): + self._reconnect_if_missing_connection() self.begin_transaction() return self diff --git a/orator/connections/connection_interface.py b/orator/connections/connection_interface.py index bbce1d86..b0f7d30d 100644 --- a/orator/connections/connection_interface.py +++ b/orator/connections/connection_interface.py @@ -2,7 +2,6 @@ class ConnectionInterface(object): - def table(self, table): """ Begin a fluent query against a database table @@ -170,5 +169,3 @@ def transaction_level(self): def pretend(self): raise NotImplementedError() - - diff --git a/orator/connections/connection_resolver_interface.py b/orator/connections/connection_resolver_interface.py index d9329b90..d28987f4 100644 --- a/orator/connections/connection_resolver_interface.py +++ b/orator/connections/connection_resolver_interface.py @@ -2,7 +2,6 @@ class ConnectionResolverInterface(object): - def connection(self, name=None): raise NotImplementedError() diff --git a/orator/connections/mysql_connection.py b/orator/connections/mysql_connection.py index 3fe0beb7..57558a7e 100644 --- a/orator/connections/mysql_connection.py +++ b/orator/connections/mysql_connection.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- +from ..utils import decode from ..utils import PY2 from .connection import Connection from ..query.grammars.mysql_grammar import MySQLQueryGrammar @@ -11,7 +12,7 @@ class MySQLConnection(Connection): - name = 'mysql' + name = "mysql" def get_default_query_grammar(self): return MySQLQueryGrammar(marker=self._marker) @@ -37,7 +38,16 @@ def get_schema_manager(self): return MySQLSchemaManager(self) def 
begin_transaction(self): - self._connection.autocommit(False) + self._reconnect_if_missing_connection() + + try: + self._connection.autocommit(False) + except Exception as e: + if self._caused_by_lost_connection(e): + self.reconnect() + self._connection.autocommit(False) + else: + raise super(MySQLConnection, self).begin_transaction() @@ -58,10 +68,10 @@ def rollback(self): self._transactions -= 1 def _get_cursor_query(self, query, bindings): - if not hasattr(self._cursor, '_last_executed') or self._pretending: + if not hasattr(self._cursor, "_last_executed") or self._pretending: return super(MySQLConnection, self)._get_cursor_query(query, bindings) if PY2: - return self._cursor._last_executed.decode() + return decode(self._cursor._last_executed) return self._cursor._last_executed diff --git a/orator/connections/postgres_connection.py b/orator/connections/postgres_connection.py index b5e44ce2..b9617978 100644 --- a/orator/connections/postgres_connection.py +++ b/orator/connections/postgres_connection.py @@ -11,7 +11,7 @@ class PostgresConnection(Connection): - name = 'pgsql' + name = "pgsql" def get_default_query_grammar(self): return PostgresQueryGrammar(marker=self._marker) @@ -64,7 +64,7 @@ def _get_cursor_query(self, query, bindings): return self._cursor.mogrify(query, bindings).decode() - if not hasattr(self._cursor, 'query'): + if not hasattr(self._cursor, "query"): return super(PostgresConnection, self)._get_cursor_query(query, bindings) if PY2: diff --git a/orator/connections/sqlite_connection.py b/orator/connections/sqlite_connection.py index 060c6efc..4f7b4e62 100644 --- a/orator/connections/sqlite_connection.py +++ b/orator/connections/sqlite_connection.py @@ -10,7 +10,7 @@ class SQLiteConnection(Connection): - name = 'sqlite' + name = "sqlite" def get_default_query_grammar(self): return self.with_table_prefix(SQLiteQueryGrammar()) @@ -25,7 +25,7 @@ def get_schema_manager(self): return SQLiteSchemaManager(self) def begin_transaction(self): - 
self._connection.isolation_level = 'DEFERRED' + self._connection.isolation_level = "DEFERRED" super(SQLiteConnection, self).begin_transaction() diff --git a/orator/connectors/connection_factory.py b/orator/connectors/connection_factory.py index c0fe3be8..49ee29e9 100644 --- a/orator/connectors/connection_factory.py +++ b/orator/connectors/connection_factory.py @@ -6,31 +6,27 @@ from .mysql_connector import MySQLConnector from .postgres_connector import PostgresConnector from .sqlite_connector import SQLiteConnector -from ..connections import ( - MySQLConnection, - PostgresConnection, - SQLiteConnection -) +from ..connections import MySQLConnection, PostgresConnection, SQLiteConnection class ConnectionFactory(object): CONNECTORS = { - 'sqlite': SQLiteConnector, - 'mysql': MySQLConnector, - 'postgres': PostgresConnector, - 'pgsql': PostgresConnector + "sqlite": SQLiteConnector, + "mysql": MySQLConnector, + "postgres": PostgresConnector, + "pgsql": PostgresConnector, } CONNECTIONS = { - 'sqlite': SQLiteConnection, - 'mysql': MySQLConnection, - 'postgres': PostgresConnection, - 'pgsql': PostgresConnection + "sqlite": SQLiteConnection, + "mysql": MySQLConnection, + "postgres": PostgresConnection, + "pgsql": PostgresConnection, } def make(self, config, name=None): - if 'read' in config: + if "read" in config: return self._create_read_write_connection(config) return self._create_single_connection(config) @@ -39,11 +35,7 @@ def _create_single_connection(self, config): conn = self.create_connector(config).connect(config) return self._create_connection( - config['driver'], - conn, - config['database'], - config.get('prefix', ''), - config + config["driver"], conn, config["database"], config.get("prefix", ""), config ) def _create_read_write_connection(self, config): @@ -59,12 +51,12 @@ def _create_read_connection(self, config): return self.create_connector(read_config).connect(read_config) def _get_read_config(self, config): - read_config = 
self._get_read_write_config(config, 'read') + read_config = self._get_read_write_config(config, "read") return self._merge_read_write_config(config, read_config) def _get_write_config(self, config): - write_config = self._get_read_write_config(config, 'write') + write_config = self._get_read_write_config(config, "write") return self._merge_read_write_config(config, write_config) @@ -78,16 +70,16 @@ def _merge_read_write_config(self, config, merge): config = config.copy() config.update(merge) - del config['read'] - del config['write'] + del config["read"] + del config["write"] return config def create_connector(self, config): - if 'driver' not in config: - raise ArgumentError('A driver must be specified') + if "driver" not in config: + raise ArgumentError("A driver must be specified") - driver = config['driver'] + driver = config["driver"] if driver not in self.CONNECTORS: raise UnsupportedDriver(driver) @@ -102,7 +94,7 @@ def register_connector(cls, name, connector): def register_connection(cls, name, connection): cls.CONNECTIONS[name] = connection - def _create_connection(self, driver, connection, database, prefix='', config=None): + def _create_connection(self, driver, connection, database, prefix="", config=None): if config is None: config = {} diff --git a/orator/connectors/connector.py b/orator/connectors/connector.py index babd3ba1..60916adb 100644 --- a/orator/connectors/connector.py +++ b/orator/connectors/connector.py @@ -6,9 +6,7 @@ class Connector(object): - RESERVED_KEYWORDS = [ - 'log_queries', 'driver', 'prefix', 'name' - ] + RESERVED_KEYWORDS = ["log_queries", "driver", "prefix", "name"] SUPPORTED_PACKAGES = [] @@ -47,16 +45,16 @@ def get_params(self): return self._params def get_database(self): - return self._params.get('database') + return self._params.get("database") def get_host(self): - return self._params.get('host') + return self._params.get("host") def get_user(self): - return self._params.get('user') + return self._params.get("user") def 
get_password(self): - return self._params.get('password') + return self._params.get("password") def get_database_platform(self): if self._platform is None: diff --git a/orator/connectors/mysql_connector.py b/orator/connectors/mysql_connector.py index 580158c6..858493d7 100644 --- a/orator/connectors/mysql_connector.py +++ b/orator/connectors/mysql_connector.py @@ -1,29 +1,32 @@ # -*- coding: utf-8 -*- import re -from pendulum import Pendulum +from pendulum import Pendulum, Date try: import MySQLdb as mysql # Fix for understanding Pendulum object import MySQLdb.converters + MySQLdb.converters.conversions[Pendulum] = MySQLdb.converters.DateTime2literal + MySQLdb.converters.conversions[Date] = MySQLdb.converters.Thing2Literal from MySQLdb.cursors import DictCursor as cursor_class - keys_fix = { - 'password': 'passwd', - 'database': 'db' - } + + keys_fix = {"password": "passwd", "database": "db"} except ImportError as e: try: import pymysql as mysql # Fix for understanding Pendulum object import pymysql.converters + pymysql.converters.conversions[Pendulum] = pymysql.converters.escape_datetime + pymysql.converters.conversions[Date] = pymysql.converters.escape_date from pymysql.cursors import DictCursor as cursor_class + keys_fix = {} except ImportError as e: mysql = None @@ -32,19 +35,21 @@ from ..dbal.platforms import MySQLPlatform, MySQL57Platform from .connector import Connector from ..utils.qmarker import qmark, denullify +from ..utils.helpers import serialize class Record(dict): - def __getattr__(self, item): try: return self[item] except KeyError: raise AttributeError(item) + def serialize(self): + return serialize(self) -class BaseDictCursor(cursor_class): +class BaseDictCursor(cursor_class): def _fetch_row(self, size=1): # Overridden for mysqclient if not self._result: @@ -52,14 +57,13 @@ def _fetch_row(self, size=1): rows = self._result.fetch_row(size, self._fetch_type) return tuple(Record(r) for r in rows) - + def _conv_row(self, row): # Overridden for pymysql 
return Record(super(BaseDictCursor, self)._conv_row(row)) class DictCursor(BaseDictCursor): - def execute(self, query, args=None): query = qmark(query) @@ -68,20 +72,22 @@ def execute(self, query, args=None): def executemany(self, query, args): query = qmark(query) - return super(DictCursor, self).executemany( - query, denullify(args) - ) + return super(DictCursor, self).executemany(query, denullify(args)) class MySQLConnector(Connector): RESERVED_KEYWORDS = [ - 'log_queries', 'driver', 'prefix', - 'engine', 'collation', - 'name', 'use_qmark' + "log_queries", + "driver", + "prefix", + "engine", + "collation", + "name", + "use_qmark", ] - SUPPORTED_PACKAGES = ['PyMySQL', 'mysqlclient'] + SUPPORTED_PACKAGES = ["PyMySQL", "mysqlclient"] def _do_connect(self, config): config = dict(config.items()) @@ -89,19 +95,16 @@ def _do_connect(self, config): config[value] = config[key] del config[key] - config['autocommit'] = True - config['cursorclass'] = self.get_cursor_class(config) + config["autocommit"] = True + config["cursorclass"] = self.get_cursor_class(config) return self.get_api().connect(**self.get_config(config)) def get_default_config(self): - return { - 'charset': 'utf8', - 'use_unicode': True - } + return {"charset": "utf8", "use_unicode": True} def get_cursor_class(self, config): - if config.get('use_qmark'): + if config.get("use_qmark"): return DictCursor return BaseDictCursor @@ -112,25 +115,27 @@ def get_api(self): def get_server_version(self): version = self._connection.get_server_info() - version_parts = re.match('^(?P\d+)(?:\.(?P\d+)(?:\.(?P\d+))?)?', version) + version_parts = re.match( + "^(?P\d+)(?:\.(?P\d+)(?:\.(?P\d+))?)?", version + ) - major = int(version_parts.group('major')) - minor = version_parts.group('minor') or 0 - patch = version_parts.group('patch') or 0 + major = int(version_parts.group("major")) + minor = version_parts.group("minor") or 0 + patch = version_parts.group("patch") or 0 minor, patch = int(minor), int(patch) - server_version = 
(major, minor, patch, '') + server_version = (major, minor, patch, "") - if 'mariadb' in version.lower(): - server_version = (major, minor, patch, 'mariadb') + if "mariadb" in version.lower(): + server_version = (major, minor, patch, "mariadb") return server_version def _create_database_platform_for_version(self, version): major, minor, _, extra = version - if extra == 'mariadb': + if extra == "mariadb": return self.get_dbal_platform() if (major, minor) >= (5, 7): diff --git a/orator/connectors/postgres_connector.py b/orator/connectors/postgres_connector.py index 4f635ebb..fd28f737 100644 --- a/orator/connectors/postgres_connector.py +++ b/orator/connectors/postgres_connector.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -import re - try: import psycopg2 import psycopg2.extras @@ -20,34 +18,31 @@ from ..dbal.platforms import PostgresPlatform from .connector import Connector from ..utils.qmarker import qmark, denullify +from ..utils.helpers import serialize class BaseDictConnection(connection_class): - def cursor(self, *args, **kwargs): - kwargs.setdefault('cursor_factory', BaseDictCursor) + kwargs.setdefault("cursor_factory", BaseDictCursor) return super(BaseDictConnection, self).cursor(*args, **kwargs) class DictConnection(BaseDictConnection): - def cursor(self, *args, **kwargs): - kwargs.setdefault('cursor_factory', DictCursor) + kwargs.setdefault("cursor_factory", DictCursor) return super(DictConnection, self).cursor(*args, **kwargs) class BaseDictCursor(cursor_class): - def __init__(self, *args, **kwargs): - kwargs['row_factory'] = DictRow + kwargs["row_factory"] = DictRow super(cursor_class, self).__init__(*args, **kwargs) self._prefetch = 1 class DictCursor(BaseDictCursor): - def execute(self, query, vars=None): query = qmark(query) @@ -56,27 +51,36 @@ def execute(self, query, vars=None): def executemany(self, query, args_seq): query = qmark(query) - return super(DictCursor, self).executemany( - query, denullify(args_seq)) + return super(DictCursor, 
self).executemany(query, denullify(args_seq)) class DictRow(row_class): - def __getattr__(self, item): try: return self[item] except KeyError: raise AttributeError(item) + def serialize(self): + serialized = {} + for column, index in self._index.items(): + serialized[column] = list.__getitem__(self, index) + + return serialize(serialized) + class PostgresConnector(Connector): RESERVED_KEYWORDS = [ - 'log_queries', 'driver', 'prefix', 'name', - 'register_unicode', 'use_qmark' + "log_queries", + "driver", + "prefix", + "name", + "register_unicode", + "use_qmark", ] - SUPPORTED_PACKAGES = ['psycopg2'] + SUPPORTED_PACKAGES = ["psycopg2"] def _do_connect(self, config): connection = self.get_api().connect( @@ -84,7 +88,7 @@ def _do_connect(self, config): **self.get_config(config) ) - if config.get('use_unicode', True): + if config.get("use_unicode", True): extensions.register_type(extensions.UNICODE, connection) extensions.register_type(extensions.UNICODEARRAY, connection) @@ -93,7 +97,7 @@ def _do_connect(self, config): return connection def get_connection_class(self, config): - if config.get('use_qmark'): + if config.get("use_qmark"): return DictConnection return BaseDictConnection @@ -101,6 +105,14 @@ def get_connection_class(self, config): def get_api(self): return psycopg2 + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + self._connection.autocommit = value + def get_dbal_platform(self): return PostgresPlatform() @@ -113,4 +125,4 @@ def get_server_version(self): minor = int_version // 100 % 100 fix = int_version % 10 - return major, minor, fix, '' + return major, minor, fix, "" diff --git a/orator/connectors/sqlite_connector.py b/orator/connectors/sqlite_connector.py index 9d18fb7d..2b165f55 100644 --- a/orator/connectors/sqlite_connector.py +++ b/orator/connectors/sqlite_connector.py @@ -1,22 +1,23 @@ # -*- coding: utf-8 -*- -from pendulum import Pendulum +from pendulum import Pendulum, 
Date try: import sqlite3 from sqlite3 import register_adapter - register_adapter(Pendulum, lambda val: val.isoformat(' ')) + register_adapter(Pendulum, lambda val: val.isoformat(" ")) + register_adapter(Date, lambda val: val.isoformat()) except ImportError: sqlite3 = None from ..dbal.platforms import SQLitePlatform +from ..utils.helpers import serialize from .connector import Connector -class DictCursor(object): - +class DictCursor(dict): def __init__(self, cursor, row): self.dict = {} self.cursor = cursor @@ -24,30 +25,27 @@ def __init__(self, cursor, row): for idx, col in enumerate(cursor.description): self.dict[col[0]] = row[idx] + super(DictCursor, self).__init__(self.dict) + def __getattr__(self, item): try: return self[item] except KeyError: return getattr(self.cursor, item) - def __getitem__(self, item): - return self.dict[item] - - def keys(self): - return self.dict.keys() - - def values(self): - return self.dict.values() - - def items(self): - return self.dict.items() + def serialize(self): + return serialize(self) class SQLiteConnector(Connector): RESERVED_KEYWORDS = [ - 'log_queries', 'driver', 'prefix', 'name', - 'foreign_keys', 'use_qmark' + "log_queries", + "driver", + "prefix", + "name", + "foreign_keys", + "use_qmark", ] def _do_connect(self, config): @@ -56,7 +54,7 @@ def _do_connect(self, config): connection.row_factory = DictCursor # We activate foreign keys support by default - if config.get('foreign_keys', True): + if config.get("foreign_keys", True): connection.execute("PRAGMA foreign_keys = ON") return connection @@ -64,6 +62,14 @@ def _do_connect(self, config): def get_api(self): return sqlite3 + @property + def isolation_level(self): + return self._connection.isolation_level + + @isolation_level.setter + def isolation_level(self, value): + self._connection.isolation_level = value + def get_dbal_platform(self): return SQLitePlatform() @@ -71,9 +77,9 @@ def is_version_aware(self): return False def get_server_version(self): - sql = 'select 
sqlite_version() AS sqlite_version' + sql = "select sqlite_version() AS sqlite_version" rows = self._connection.execute(sql).fetchall() - version = rows[0]['sqlite_version'] + version = rows[0]["sqlite_version"] - return tuple(version.split('.')[:3] + ['']) + return tuple(version.split(".")[:3] + [""]) diff --git a/orator/database_manager.py b/orator/database_manager.py index 2c0b8bd7..e304d755 100644 --- a/orator/database_manager.py +++ b/orator/database_manager.py @@ -6,11 +6,10 @@ from .connectors.connection_factory import ConnectionFactory from .exceptions import ArgumentError -logger = logging.getLogger('orator.database_manager') +logger = logging.getLogger("orator.database_manager") class BaseDatabaseManager(ConnectionResolverInterface): - def __init__(self, config, factory=ConnectionFactory()): """ :param config: The connections configuration @@ -39,7 +38,7 @@ def connection(self, name=None): name, type = self._parse_connection_name(name) if name not in self._connections: - logger.debug('Initiating connection %s' % name) + logger.debug("Initiating connection %s" % name) connection = self._make_connection(name) self._set_connection_for_type(connection, type) @@ -61,8 +60,8 @@ def _parse_connection_name(self, name): if name is None: name = self.get_default_connection() - if name.endswith(('::read', '::write')): - return name.split('::', 1) + if name.endswith(("::read", "::write")): + return name.split("::", 1) return name, None @@ -87,7 +86,7 @@ def disconnect(self, name=None): if name is None: name = self.get_default_connection() - logger.debug('Disconnecting %s' % name) + logger.debug("Disconnecting %s" % name) if name in self._connections: self._connections[name].disconnect() @@ -96,7 +95,7 @@ def reconnect(self, name=None): if name is None: name = self.get_default_connection() - logger.debug('Reconnecting %s' % name) + logger.debug("Reconnecting %s" % name) self.disconnect(name) @@ -106,25 +105,27 @@ def reconnect(self, name=None): return 
self._refresh_api_connections(name) def _refresh_api_connections(self, name): - logger.debug('Refreshing api connections for %s' % name) + logger.debug("Refreshing api connections for %s" % name) fresh = self._make_connection(name) - return self._connections[name]\ - .set_connection(fresh.get_connection())\ + return ( + self._connections[name] + .set_connection(fresh.get_connection()) .set_read_connection(fresh.get_read_connection()) + ) def _make_connection(self, name): - logger.debug('Making connection for %s' % name) + logger.debug("Making connection for %s" % name) config = self._get_config(name) - if 'name' not in config: - config['name'] = name + if "name" not in config: + config["name"] = name if name in self._extensions: return self._extensions[name](config, name) - driver = config['driver'] + driver = config["driver"] if driver in self._extensions: return self._extensions[driver](config, name) @@ -132,7 +133,7 @@ def _make_connection(self, name): return self._factory.make(config, name) def _prepare(self, connection): - logger.debug('Preparing connection %s' % connection.get_name()) + logger.debug("Preparing connection %s" % connection.get_name()) def reconnector(connection_): self.reconnect(connection_.get_name()) @@ -142,9 +143,9 @@ def reconnector(connection_): return connection def _set_connection_for_type(self, connection, type): - if type == 'read': + if type == "read": connection.set_connection(connection.get_read_api()) - elif type == 'write': + elif type == "write": connection.set_read_connection(connection.get_api()) return connection @@ -157,7 +158,7 @@ def _get_config(self, name): config = connections.get(name) if not config: - raise ArgumentError('Database [%s] not configured' % name) + raise ArgumentError("Database [%s] not configured" % name) return config @@ -165,11 +166,11 @@ def get_default_connection(self): if len(self._config) == 1: return list(self._config.keys())[0] - return self._config['default'] + return self._config["default"] def 
set_default_connection(self, name): if name is not None: - self._config['default'] = name + self._config["default"] = name def extend(self, name, resolver): self._extensions[name] = resolver diff --git a/orator/dbal/__init__.py b/orator/dbal/__init__.py index 633f8661..40a96afc 100644 --- a/orator/dbal/__init__.py +++ b/orator/dbal/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/orator/dbal/abstract_asset.py b/orator/dbal/abstract_asset.py index 2fe1f3a0..06287e1f 100644 --- a/orator/dbal/abstract_asset.py +++ b/orator/dbal/abstract_asset.py @@ -24,8 +24,8 @@ def _set_name(self, name): self._quoted = True name = self._trim_quotes(name) - if '.' in name: - parts = name.split('.', 1) + if "." in name: + parts = name.split(".", 1) self._namespace = parts[0] name = parts[1] @@ -47,7 +47,7 @@ def get_shortest_name(self, default_namespace): def get_full_qualified_name(self, default_namespace): name = self.get_name() if not self._namespace: - name = default_namespace + '.' + name + name = default_namespace + "." + name return name.lower() @@ -55,33 +55,34 @@ def is_quoted(self): return self._quoted def _is_identifier_quoted(self, identifier): - return len(identifier) > 0\ - and (identifier[0] == '`' or identifier[0] == '"' or identifier[0] == '[') + return len(identifier) > 0 and ( + identifier[0] == "`" or identifier[0] == '"' or identifier[0] == "[" + ) def _trim_quotes(self, identifier): - return re.sub('[`"\[\]]', '', identifier) + return re.sub('[`"\[\]]', "", identifier) def get_name(self): if self._namespace: - return self._namespace + '.' + self._name + return self._namespace + "." 
+ self._name return self._name def get_quoted_name(self, platform): keywords = platform.get_reserved_keywords_list() - parts = self.get_name().split('.') + parts = self.get_name().split(".") for k, v in enumerate(parts): if self._quoted or keywords.is_keyword(v): parts[k] = platform.quote_identifier(v) - return '.'.join(parts) + return ".".join(parts) - def _generate_identifier_name(self, columns, prefix='', max_size=30): + def _generate_identifier_name(self, columns, prefix="", max_size=30): """ Generates an identifier from a list of column names obeying a certain string length. """ - hash = '' + hash = "" for column in columns: - hash += '%x' % binascii.crc32(encode(str(column))) + hash += "%x" % binascii.crc32(encode(str(column))) - return (prefix + '_' + hash)[:max_size] + return (prefix + "_" + hash)[:max_size] diff --git a/orator/dbal/column.py b/orator/dbal/column.py index ce970af6..59ce0b3b 100644 --- a/orator/dbal/column.py +++ b/orator/dbal/column.py @@ -5,7 +5,6 @@ class Column(AbstractAsset): - def __init__(self, name, type, options=None): self._set_name(name) self._type = type @@ -25,7 +24,7 @@ def __init__(self, name, type, options=None): def set_options(self, options): for key, value in options.items(): - method = 'set_%s' % key + method = "set_%s" % key if hasattr(self, method): getattr(self, method)(value) @@ -59,7 +58,11 @@ def set_length(self, length): return self def set_precision(self, precision): - if precision is None or isinstance(precision, basestring) and not precision.isdigit(): + if ( + precision is None + or isinstance(precision, basestring) + and not precision.isdigit() + ): precision = 10 self._precision = int(precision) @@ -123,21 +126,19 @@ def get_default(self): def to_dict(self): d = { - 'name': self._name, - 'type': self._type, - 'default': self._default, - 'notnull': self._notnull, - 'length': self._length, - 'precision': self._precision, - 'scale': self._scale, - 'fixed': self._fixed, - 'unsigned': self._unsigned, - 
'autoincrement': self._autoincrement, - 'extra': self._extra + "name": self._name, + "type": self._type, + "default": self._default, + "notnull": self._notnull, + "length": self._length, + "precision": self._precision, + "scale": self._scale, + "fixed": self._fixed, + "unsigned": self._unsigned, + "autoincrement": self._autoincrement, + "extra": self._extra, } d.update(self._platform_options) return d - - diff --git a/orator/dbal/column_diff.py b/orator/dbal/column_diff.py index a3a3a796..4ad0da67 100644 --- a/orator/dbal/column_diff.py +++ b/orator/dbal/column_diff.py @@ -4,8 +4,9 @@ class ColumnDiff(object): - - def __init__(self, old_column_name, column, changed_properties=None, from_column=None): + def __init__( + self, old_column_name, column, changed_properties=None, from_column=None + ): self.old_column_name = old_column_name self.column = column self.changed_properties = changed_properties diff --git a/orator/dbal/comparator.py b/orator/dbal/comparator.py index 838dbf5a..4c889528 100644 --- a/orator/dbal/comparator.py +++ b/orator/dbal/comparator.py @@ -42,12 +42,16 @@ def diff_table(self, table1, table2): continue # See if column has changed properties in table2 - changed_properties = self.diff_column(column, table2.get_column(column_name)) + changed_properties = self.diff_column( + column, table2.get_column(column_name) + ) if changed_properties: - column_diff = ColumnDiff(column.get_name(), - table2.get_column(column_name), - changed_properties) + column_diff = ColumnDiff( + column.get_name(), + table2.get_column(column_name), + changed_properties, + ) column_diff.from_column = column table_differences.changed_columns[column.get_name()] = column_diff changes += 1 @@ -59,7 +63,9 @@ def diff_table(self, table1, table2): # See if all the fields in table1 exist in table2 for index_name, index in table2_indexes.items(): - if (index.is_primary() and not table1.has_primary_key()) or table1.has_index(index_name): + if ( + index.is_primary() and not 
table1.has_primary_key() + ) or table1.has_index(index_name): continue table_differences.added_indexes[index_name] = index @@ -67,8 +73,9 @@ def diff_table(self, table1, table2): # See if there are any removed fields in table2 for index_name, index in table1_indexes.items(): - if (index.is_primary() and not table2.has_primary_key())\ - or (not index.is_primary() and not table2.has_index(index_name)): + if (index.is_primary() and not table2.has_primary_key()) or ( + not index.is_primary() and not table2.has_index(index_name) + ): table_differences.removed_indexes[index_name] = index changes += 1 continue @@ -127,7 +134,9 @@ def detect_column_renamings(self, table_differences): if added_column.get_name() not in rename_candidates: rename_candidates[added_column.get_name()] = [] - rename_candidates[added_column.get_name()].append((removed_column, added_column, added_column_name)) + rename_candidates[added_column.get_name()].append( + (removed_column, added_column, added_column_name) + ) for candidate_columns in rename_candidates.values(): if len(candidate_columns) == 1: @@ -136,7 +145,9 @@ def detect_column_renamings(self, table_differences): added_column_name = added_column.get_name().lower() if removed_column_name not in table_differences.renamed_columns: - table_differences.renamed_columns[removed_column_name] = added_column + table_differences.renamed_columns[ + removed_column_name + ] = added_column del table_differences.added_columns[added_column_name] del table_differences.removed_columns[removed_column_name] @@ -159,7 +170,9 @@ def detect_index_renamings(self, table_differences): if added_index.get_name() not in rename_candidates: rename_candidates[added_index.get_name()] = [] - rename_candidates[added_index.get_name()].append((removed_index, added_index, added_index_name)) + rename_candidates[added_index.get_name()].append( + (removed_index, added_index, added_index_name) + ) for candidate_indexes in rename_candidates.values(): # If the current rename 
candidate contains exactly one semantically equal index, @@ -185,19 +198,30 @@ def diff_foreign_key(self, key1, key2): :rtype: bool """ - key1_unquoted_local_columns = [c.lower() for c in key1.get_unquoted_local_columns()] - key2_unquoted_local_columns = [c.lower() for c in key2.get_unquoted_local_columns()] + key1_unquoted_local_columns = [ + c.lower() for c in key1.get_unquoted_local_columns() + ] + key2_unquoted_local_columns = [ + c.lower() for c in key2.get_unquoted_local_columns() + ] if key1_unquoted_local_columns != key2_unquoted_local_columns: return True - key1_unquoted_foreign_columns = [c.lower() for c in key1.get_unquoted_foreign_columns()] - key2_unquoted_foreign_columns = [c.lower() for c in key2.get_unquoted_foreign_columns()] + key1_unquoted_foreign_columns = [ + c.lower() for c in key1.get_unquoted_foreign_columns() + ] + key2_unquoted_foreign_columns = [ + c.lower() for c in key2.get_unquoted_foreign_columns() + ] if key1_unquoted_foreign_columns != key2_unquoted_foreign_columns: return True - if key1.get_unqualified_foreign_table_name() != key2.get_unqualified_foreign_table_name(): + if ( + key1.get_unqualified_foreign_table_name() + != key2.get_unqualified_foreign_table_name() + ): return True if key1.on_update() != key2.on_update(): @@ -222,34 +246,39 @@ def diff_column(self, column1, column2): changed_properties = [] - for prop in ['type', 'notnull', 'unsigned', 'autoincrement']: + for prop in ["type", "notnull", "unsigned", "autoincrement"]: if properties1[prop] != properties2[prop]: changed_properties.append(prop) - if properties1['default'] != properties2['default']\ - or (properties1['default'] is None and properties2['default'] is not None)\ - or (properties2['default'] is None and properties1['default'] is not None): - changed_properties.append('default') - - if properties1['type'] == 'string' and properties1['type'] != 'guid'\ - or properties1['type'] in ['binary', 'blob']: - length1 = properties1['length'] or 255 - length2 = 
properties2['length'] or 255 + if ( + properties1["default"] != properties2["default"] + or (properties1["default"] is None and properties2["default"] is not None) + or (properties2["default"] is None and properties1["default"] is not None) + ): + changed_properties.append("default") + + if ( + properties1["type"] == "string" + and properties1["type"] != "guid" + or properties1["type"] in ["binary", "blob"] + ): + length1 = properties1["length"] or 255 + length2 = properties2["length"] or 255 if length1 != length2: - changed_properties.append('length') + changed_properties.append("length") - if properties1['fixed'] != properties2['fixed']: - changed_properties.append('fixed') - elif properties1['type'] in ['decimal', 'float', 'double precision']: - precision1 = properties1['precision'] or 10 - precision2 = properties2['precision'] or 10 + if properties1["fixed"] != properties2["fixed"]: + changed_properties.append("fixed") + elif properties1["type"] in ["decimal", "float", "double precision"]: + precision1 = properties1["precision"] or 10 + precision2 = properties2["precision"] or 10 if precision1 != precision2: - changed_properties.append('precision') + changed_properties.append("precision") - if properties1['scale'] != properties2['scale']: - changed_properties.append('scale') + if properties1["scale"] != properties2["scale"]: + changed_properties.append("scale") return list(set(changed_properties)) diff --git a/orator/dbal/exceptions/__init__.py b/orator/dbal/exceptions/__init__.py index 191a411d..6a0d8270 100644 --- a/orator/dbal/exceptions/__init__.py +++ b/orator/dbal/exceptions/__init__.py @@ -7,7 +7,6 @@ class DBALException(Exception): class InvalidPlatformSpecified(DBALException): - def __init__(self, index_name, table_name): message = 'Invalid "platform" option specified, need to give an instance of dbal.platforms.Platform' @@ -20,7 +19,6 @@ class SchemaException(DBALException): class IndexDoesNotExist(SchemaException): - def __init__(self, index_name, 
table_name): message = 'Index "%s" does not exist on table "%s".' % (index_name, table_name) @@ -28,15 +26,16 @@ def __init__(self, index_name, table_name): class IndexAlreadyExists(SchemaException): - def __init__(self, index_name, table_name): - message = 'An index with name "%s" already exists on table "%s".' % (index_name, table_name) + message = 'An index with name "%s" already exists on table "%s".' % ( + index_name, + table_name, + ) super(IndexAlreadyExists, self).__init__(message) class IndexNameInvalid(SchemaException): - def __init__(self, index_name): message = 'Invalid index name "%s" given, has to be [a-zA-Z0-9_]' % index_name @@ -44,7 +43,6 @@ def __init__(self, index_name): class ColumnDoesNotExist(SchemaException): - def __init__(self, column, table_name): message = 'Column "%s" does not exist on table "%s".' % (column, table_name) @@ -52,16 +50,20 @@ def __init__(self, column, table_name): class ColumnAlreadyExists(SchemaException): - def __init__(self, column, table_name): - message = 'An column with name "%s" already exists on table "%s".' % (column, table_name) + message = 'An column with name "%s" already exists on table "%s".' % ( + column, + table_name, + ) super(ColumnAlreadyExists, self).__init__(message) class ForeignKeyDoesNotExist(SchemaException): - def __init__(self, constraint, table_name): - message = 'Foreign key "%s" does not exist on table "%s".' % (constraint, table_name) + message = 'Foreign key "%s" does not exist on table "%s".' % ( + constraint, + table_name, + ) super(ForeignKeyDoesNotExist, self).__init__(message) diff --git a/orator/dbal/foreign_key_constraint.py b/orator/dbal/foreign_key_constraint.py index 192c99bb..52987bbd 100644 --- a/orator/dbal/foreign_key_constraint.py +++ b/orator/dbal/foreign_key_constraint.py @@ -10,9 +10,14 @@ class ForeignKeyConstraint(AbstractAsset): An abstraction class for a foreign key constraint. 
""" - def __init__(self, local_column_names, - foreign_table_name, foreign_column_names, - name=None, options=None): + def __init__( + self, + local_column_names, + foreign_table_name, + foreign_column_names, + name=None, + options=None, + ): """ Constructor. @@ -152,7 +157,7 @@ def get_unqualified_foreign_table_name(self): :rtype: str """ - parts = self.get_foreign_table_name().split('.') + parts = self.get_foreign_table_name().split(".") return parts[-1].lower() @@ -227,7 +232,7 @@ def on_update(self): :rtype: str or None """ - return self._on_event('on_update') + return self._on_event("on_update") def on_delete(self): """ @@ -236,7 +241,7 @@ def on_delete(self): :rtype: str or None """ - return self._on_event('on_delete') + return self._on_event("on_delete") def _on_event(self, event): """ @@ -251,7 +256,7 @@ def _on_event(self, event): if self.has_option(event): on_event = self.get_option(event).upper() - if on_event not in ['NO ACTION', 'RESTRICT']: + if on_event not in ["NO ACTION", "RESTRICT"]: return on_event return False diff --git a/orator/dbal/identifier.py b/orator/dbal/identifier.py index 3fc779a9..0ffbfcda 100644 --- a/orator/dbal/identifier.py +++ b/orator/dbal/identifier.py @@ -4,6 +4,5 @@ class Identifier(AbstractAsset): - def __init__(self, identifier): self._set_name(identifier) diff --git a/orator/dbal/index.py b/orator/dbal/index.py index b3616d91..2dbb6690 100644 --- a/orator/dbal/index.py +++ b/orator/dbal/index.py @@ -10,7 +10,9 @@ class Index(AbstractAsset): An abstraction class for an index. """ - def __init__(self, name, columns, is_unique=False, is_primary=False, flags=None, options=None): + def __init__( + self, name, columns, is_unique=False, is_primary=False, flags=None, options=None + ): """ Constructor. 
@@ -124,7 +126,9 @@ def spans_columns(self, column_names): for i in range(number_of_columns): column = self._trim_quotes(columns[i].lower()) - if i >= len(column_names) or column != self._trim_quotes(column_names[i].lower()): + if i >= len(column_names) or column != self._trim_quotes( + column_names[i].lower() + ): same_columns = False return same_columns @@ -176,11 +180,14 @@ def same_partial_index(self, other): :rtype: bool """ - if (self.has_option('where') and other.has_option('where') - and self.get_option('where') == other.get_option('where')): + if ( + self.has_option("where") + and other.has_option("where") + and self.get_option("where") == other.get_option("where") + ): return True - if not self.has_option('where') and not other.has_option('where'): + if not self.has_option("where") and not other.has_option("where"): return True return False @@ -201,7 +208,11 @@ def overrules(self, other): return False same_columns = self.spans_columns(other.get_columns()) - if same_columns and (self.is_primary() or self.is_unique()) and self.same_partial_index(other): + if ( + same_columns + and (self.is_primary() or self.is_unique()) + and self.same_partial_index(other) + ): return True return False diff --git a/orator/dbal/mysql_schema_manager.py b/orator/dbal/mysql_schema_manager.py index d9486df9..d72a5bd6 100644 --- a/orator/dbal/mysql_schema_manager.py +++ b/orator/dbal/mysql_schema_manager.py @@ -9,25 +9,24 @@ class MySQLSchemaManager(SchemaManager): - def _get_portable_table_column_definition(self, table_column): - db_type = table_column['type'].lower() - type_match = re.match('(.+)\((.*)\).*', db_type) + db_type = table_column["type"].lower() + type_match = re.match("(.+)\((.*)\).*", db_type) if type_match: db_type = type_match.group(1) - if 'length' in table_column: - length = table_column['length'] + if "length" in table_column: + length = table_column["length"] else: - if type_match and type_match.group(2) and ',' not in type_match.group(2): + if type_match 
and type_match.group(2) and "," not in type_match.group(2): length = int(type_match.group(2)) else: length = 0 fixed = None - if 'name' not in table_column: - table_column['name'] = '' + if "name" not in table_column: + table_column["name"] = "" precision = None scale = None @@ -35,55 +34,55 @@ def _get_portable_table_column_definition(self, table_column): type = self._platform.get_type_mapping(db_type) - if db_type in ['char', 'binary']: + if db_type in ["char", "binary"]: fixed = True - elif db_type in ['float', 'double', 'real', 'decimal', 'numeric']: - match = re.match('([A-Za-z]+\(([0-9]+),([0-9]+)\))', table_column['type']) + elif db_type in ["float", "double", "real", "decimal", "numeric"]: + match = re.match("([A-Za-z]+\(([0-9]+),([0-9]+)\))", table_column["type"]) if match: precision = match.group(1) scale = match.group(2) length = None - elif db_type == 'tinytext': + elif db_type == "tinytext": length = MySQLPlatform.LENGTH_LIMIT_TINYTEXT - elif db_type == 'text': + elif db_type == "text": length = MySQLPlatform.LENGTH_LIMIT_TEXT - elif db_type == 'mediumtext': + elif db_type == "mediumtext": length = MySQLPlatform.LENGTH_LIMIT_MEDIUMTEXT - elif db_type == 'tinyblob': + elif db_type == "tinyblob": length = MySQLPlatform.LENGTH_LIMIT_TINYBLOB - elif db_type == 'blob': + elif db_type == "blob": length = MySQLPlatform.LENGTH_LIMIT_BLOB - elif db_type == 'mediumblob': + elif db_type == "mediumblob": length = MySQLPlatform.LENGTH_LIMIT_MEDIUMBLOB - elif db_type in ['tinyint', 'smallint', 'mediumint', 'int', 'bigint', 'year']: + elif db_type in ["tinyint", "smallint", "mediumint", "int", "bigint", "year"]: length = None - elif db_type == 'enum': + elif db_type == "enum": length = None - extra['definition'] = '({})'.format(type_match.group(2)) + extra["definition"] = "({})".format(type_match.group(2)) if length is None or length == 0: length = None options = { - 'length': length, - 'unsigned': table_column['type'].find('unsigned') != -1, - 'fixed': fixed, - 
'notnull': table_column['null'] != 'YES', - 'default': table_column.get('default'), - 'precision': None, - 'scale': None, - 'autoincrement': table_column['extra'].find('auto_increment') != -1, - 'extra': extra, + "length": length, + "unsigned": table_column["type"].find("unsigned") != -1, + "fixed": fixed, + "notnull": table_column["null"] != "YES", + "default": table_column.get("default"), + "precision": None, + "scale": None, + "autoincrement": table_column["extra"].find("auto_increment") != -1, + "extra": extra, } if scale is not None and precision is not None: - options['scale'] = scale - options['precision'] = precision + options["scale"] = scale + options["precision"] = precision - column = Column(table_column['field'], type, options) + column = Column(table_column["field"], type, options) - if 'collation' in table_column: - column.set_platform_option('collation', table_column['collation']) + if "collation" in table_column: + column.set_platform_option("collation", table_column["collation"]) return column @@ -91,55 +90,61 @@ def _get_portable_table_indexes_list(self, table_indexes, table_name): new = [] for v in table_indexes: v = dict((k.lower(), value) for k, value in v.items()) - if v['key_name'] == 'PRIMARY': - v['primary'] = True + if v["key_name"] == "PRIMARY": + v["primary"] = True else: - v['primary'] = False + v["primary"] = False - if 'FULLTEXT' in v['index_type']: - v['flags'] = {'FULLTEXT': True} + if "FULLTEXT" in v["index_type"]: + v["flags"] = {"FULLTEXT": True} else: - v['flags'] = {'SPATIAL': True} - + v["flags"] = {"SPATIAL": True} + new.append(v) - - return super(MySQLSchemaManager, self)._get_portable_table_indexes_list(new, table_name) + + return super(MySQLSchemaManager, self)._get_portable_table_indexes_list( + new, table_name + ) def _get_portable_table_foreign_keys_list(self, table_foreign_keys): foreign_keys = OrderedDict() for value in table_foreign_keys: value = dict((k.lower(), v) for k, v in value.items()) - name = 
value.get('constraint_name', '') + name = value.get("constraint_name", "") if name not in foreign_keys: - if 'delete_rule' not in value or value['delete_rule'] == 'RESTRICT': - value['delete_rule'] = '' + if "delete_rule" not in value or value["delete_rule"] == "RESTRICT": + value["delete_rule"] = "" - if 'update_rule' not in value or value['update_rule'] == 'RESTRICT': - value['update_rule'] = '' + if "update_rule" not in value or value["update_rule"] == "RESTRICT": + value["update_rule"] = "" foreign_keys[name] = { - 'name': name, - 'local': [], - 'foreign': [], - 'foreign_table': value['referenced_table_name'], - 'on_delete': value['delete_rule'], - 'on_update': value['update_rule'] + "name": name, + "local": [], + "foreign": [], + "foreign_table": value["referenced_table_name"], + "on_delete": value["delete_rule"], + "on_update": value["update_rule"], } - foreign_keys[name]['local'].append(value['column_name']) - foreign_keys[name]['foreign'].append(value['referenced_column_name']) + foreign_keys[name]["local"].append(value["column_name"]) + foreign_keys[name]["foreign"].append(value["referenced_column_name"]) result = [] for constraint in foreign_keys.values(): - result.append(ForeignKeyConstraint( - constraint['local'], constraint['foreign_table'], - constraint['foreign'], constraint['name'], - { - 'on_delete': constraint['on_delete'], - 'on_update': constraint['on_update'] - } - )) + result.append( + ForeignKeyConstraint( + constraint["local"], + constraint["foreign_table"], + constraint["foreign"], + constraint["name"], + { + "on_delete": constraint["on_delete"], + "on_update": constraint["on_update"], + }, + ) + ) return result diff --git a/orator/dbal/platforms/keywords/__init__.py b/orator/dbal/platforms/keywords/__init__.py index 633f8661..40a96afc 100644 --- a/orator/dbal/platforms/keywords/__init__.py +++ b/orator/dbal/platforms/keywords/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/orator/dbal/platforms/keywords/mysql_keywords.py 
b/orator/dbal/platforms/keywords/mysql_keywords.py index 390f86c5..1814afa8 100644 --- a/orator/dbal/platforms/keywords/mysql_keywords.py +++ b/orator/dbal/platforms/keywords/mysql_keywords.py @@ -6,232 +6,232 @@ class MySQLKeywords(KeywordList): KEYWORDS = [ - 'ADD', - 'ALL', - 'ALTER', - 'ANALYZE', - 'AND', - 'AS', - 'ASC', - 'ASENSITIVE', - 'BEFORE', - 'BETWEEN', - 'BIGINT', - 'BINARY', - 'BLOB', - 'BOTH', - 'BY', - 'CALL', - 'CASCADE', - 'CASE', - 'CHANGE', - 'CHAR', - 'CHARACTER', - 'CHECK', - 'COLLATE', - 'COLUMN', - 'CONDITION', - 'CONNECTION', - 'CONSTRAINT', - 'CONTINUE', - 'CONVERT', - 'CREATE', - 'CROSS', - 'CURRENT_DATE', - 'CURRENT_TIME', - 'CURRENT_TIMESTAMP', - 'CURRENT_USER', - 'CURSOR', - 'DATABASE', - 'DATABASES', - 'DAY_HOUR', - 'DAY_MICROSECOND', - 'DAY_MINUTE', - 'DAY_SECOND', - 'DEC', - 'DECIMAL', - 'DECLARE', - 'DEFAULT', - 'DELAYED', - 'DELETE', - 'DESC', - 'DESCRIBE', - 'DETERMINISTIC', - 'DISTINCT', - 'DISTINCTROW', - 'DIV', - 'DOUBLE', - 'DROP', - 'DUAL', - 'EACH', - 'ELSE', - 'ELSEIF', - 'ENCLOSED', - 'ESCAPED', - 'EXISTS', - 'EXIT', - 'EXPLAIN', - 'FALSE', - 'FETCH', - 'FLOAT', - 'FLOAT4', - 'FLOAT8', - 'FOR', - 'FORCE', - 'FOREIGN', - 'FROM', - 'FULLTEXT', - 'GOTO', - 'GRANT', - 'GROUP', - 'HAVING', - 'HIGH_PRIORITY', - 'HOUR_MICROSECOND', - 'HOUR_MINUTE', - 'HOUR_SECOND', - 'IF', - 'IGNORE', - 'IN', - 'INDEX', - 'INFILE', - 'INNER', - 'INOUT', - 'INSENSITIVE', - 'INSERT', - 'INT', - 'INT1', - 'INT2', - 'INT3', - 'INT4', - 'INT8', - 'INTEGER', - 'INTERVAL', - 'INTO', - 'IS', - 'ITERATE', - 'JOIN', - 'KEY', - 'KEYS', - 'KILL', - 'LABEL', - 'LEADING', - 'LEAVE', - 'LEFT', - 'LIKE', - 'LIMIT', - 'LINES', - 'LOAD', - 'LOCALTIME', - 'LOCALTIMESTAMP', - 'LOCK', - 'LONG', - 'LONGBLOB', - 'LONGTEXT', - 'LOOP', - 'LOW_PRIORITY', - 'MATCH', - 'MEDIUMBLOB', - 'MEDIUMINT', - 'MEDIUMTEXT', - 'MIDDLEINT', - 'MINUTE_MICROSECOND', - 'MINUTE_SECOND', - 'MOD', - 'MODIFIES', - 'NATURAL', - 'NOT', - 'NO_WRITE_TO_BINLOG', - 'NULL', - 'NUMERIC', - 'ON', - 
'OPTIMIZE', - 'OPTION', - 'OPTIONALLY', - 'OR', - 'ORDER', - 'OUT', - 'OUTER', - 'OUTFILE', - 'PRECISION', - 'PRIMARY', - 'PROCEDURE', - 'PURGE', - 'RAID0', - 'RANGE', - 'READ', - 'READS', - 'REAL', - 'REFERENCES', - 'REGEXP', - 'RELEASE', - 'RENAME', - 'REPEAT', - 'REPLACE', - 'REQUIRE', - 'RESTRICT', - 'RETURN', - 'REVOKE', - 'RIGHT', - 'RLIKE', - 'SCHEMA', - 'SCHEMAS', - 'SECOND_MICROSECOND', - 'SELECT', - 'SENSITIVE', - 'SEPARATOR', - 'SET', - 'SHOW', - 'SMALLINT', - 'SONAME', - 'SPATIAL', - 'SPECIFIC', - 'SQL', - 'SQLEXCEPTION', - 'SQLSTATE', - 'SQLWARNING', - 'SQL_BIG_RESULT', - 'SQL_CALC_FOUND_ROWS', - 'SQL_SMALL_RESULT', - 'SSL', - 'STARTING', - 'STRAIGHT_JOIN', - 'TABLE', - 'TERMINATED', - 'THEN', - 'TINYBLOB', - 'TINYINT', - 'TINYTEXT', - 'TO', - 'TRAILING', - 'TRIGGER', - 'TRUE', - 'UNDO', - 'UNION', - 'UNIQUE', - 'UNLOCK', - 'UNSIGNED', - 'UPDATE', - 'USAGE', - 'USE', - 'USING', - 'UTC_DATE', - 'UTC_TIME', - 'UTC_TIMESTAMP', - 'VALUES', - 'VARBINARY', - 'VARCHAR', - 'VARCHARACTER', - 'VARYING', - 'WHEN', - 'WHERE', - 'WHILE', - 'WITH', - 'WRITE', - 'X509', - 'XOR', - 'YEAR_MONTH', - 'ZEROFILL' + "ADD", + "ALL", + "ALTER", + "ANALYZE", + "AND", + "AS", + "ASC", + "ASENSITIVE", + "BEFORE", + "BETWEEN", + "BIGINT", + "BINARY", + "BLOB", + "BOTH", + "BY", + "CALL", + "CASCADE", + "CASE", + "CHANGE", + "CHAR", + "CHARACTER", + "CHECK", + "COLLATE", + "COLUMN", + "CONDITION", + "CONNECTION", + "CONSTRAINT", + "CONTINUE", + "CONVERT", + "CREATE", + "CROSS", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "CURSOR", + "DATABASE", + "DATABASES", + "DAY_HOUR", + "DAY_MICROSECOND", + "DAY_MINUTE", + "DAY_SECOND", + "DEC", + "DECIMAL", + "DECLARE", + "DEFAULT", + "DELAYED", + "DELETE", + "DESC", + "DESCRIBE", + "DETERMINISTIC", + "DISTINCT", + "DISTINCTROW", + "DIV", + "DOUBLE", + "DROP", + "DUAL", + "EACH", + "ELSE", + "ELSEIF", + "ENCLOSED", + "ESCAPED", + "EXISTS", + "EXIT", + "EXPLAIN", + "FALSE", + "FETCH", + "FLOAT", + "FLOAT4", 
+ "FLOAT8", + "FOR", + "FORCE", + "FOREIGN", + "FROM", + "FULLTEXT", + "GOTO", + "GRANT", + "GROUP", + "HAVING", + "HIGH_PRIORITY", + "HOUR_MICROSECOND", + "HOUR_MINUTE", + "HOUR_SECOND", + "IF", + "IGNORE", + "IN", + "INDEX", + "INFILE", + "INNER", + "INOUT", + "INSENSITIVE", + "INSERT", + "INT", + "INT1", + "INT2", + "INT3", + "INT4", + "INT8", + "INTEGER", + "INTERVAL", + "INTO", + "IS", + "ITERATE", + "JOIN", + "KEY", + "KEYS", + "KILL", + "LABEL", + "LEADING", + "LEAVE", + "LEFT", + "LIKE", + "LIMIT", + "LINES", + "LOAD", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOCK", + "LONG", + "LONGBLOB", + "LONGTEXT", + "LOOP", + "LOW_PRIORITY", + "MATCH", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "MIDDLEINT", + "MINUTE_MICROSECOND", + "MINUTE_SECOND", + "MOD", + "MODIFIES", + "NATURAL", + "NOT", + "NO_WRITE_TO_BINLOG", + "NULL", + "NUMERIC", + "ON", + "OPTIMIZE", + "OPTION", + "OPTIONALLY", + "OR", + "ORDER", + "OUT", + "OUTER", + "OUTFILE", + "PRECISION", + "PRIMARY", + "PROCEDURE", + "PURGE", + "RAID0", + "RANGE", + "READ", + "READS", + "REAL", + "REFERENCES", + "REGEXP", + "RELEASE", + "RENAME", + "REPEAT", + "REPLACE", + "REQUIRE", + "RESTRICT", + "RETURN", + "REVOKE", + "RIGHT", + "RLIKE", + "SCHEMA", + "SCHEMAS", + "SECOND_MICROSECOND", + "SELECT", + "SENSITIVE", + "SEPARATOR", + "SET", + "SHOW", + "SMALLINT", + "SONAME", + "SPATIAL", + "SPECIFIC", + "SQL", + "SQLEXCEPTION", + "SQLSTATE", + "SQLWARNING", + "SQL_BIG_RESULT", + "SQL_CALC_FOUND_ROWS", + "SQL_SMALL_RESULT", + "SSL", + "STARTING", + "STRAIGHT_JOIN", + "TABLE", + "TERMINATED", + "THEN", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "TO", + "TRAILING", + "TRIGGER", + "TRUE", + "UNDO", + "UNION", + "UNIQUE", + "UNLOCK", + "UNSIGNED", + "UPDATE", + "USAGE", + "USE", + "USING", + "UTC_DATE", + "UTC_TIME", + "UTC_TIMESTAMP", + "VALUES", + "VARBINARY", + "VARCHAR", + "VARCHARACTER", + "VARYING", + "WHEN", + "WHERE", + "WHILE", + "WITH", + "WRITE", + "X509", + "XOR", + "YEAR_MONTH", + "ZEROFILL", ] def 
get_name(self): - return 'MySQL' + return "MySQL" diff --git a/orator/dbal/platforms/keywords/postgresql_keywords.py b/orator/dbal/platforms/keywords/postgresql_keywords.py index afd7b1fa..3bf1d6b0 100644 --- a/orator/dbal/platforms/keywords/postgresql_keywords.py +++ b/orator/dbal/platforms/keywords/postgresql_keywords.py @@ -6,94 +6,94 @@ class PostgreSQLKeywords(KeywordList): KEYWORDS = [ - 'ALL', - 'ANALYSE', - 'ANALYZE', - 'AND', - 'ANY', - 'AS', - 'ASC', - 'AUTHORIZATION', - 'BETWEEN', - 'BINARY', - 'BOTH', - 'CASE', - 'CAST', - 'CHECK', - 'COLLATE', - 'COLUMN', - 'CONSTRAINT', - 'CREATE', - 'CURRENT_DATE', - 'CURRENT_TIME', - 'CURRENT_TIMESTAMP', - 'CURRENT_USER', - 'DEFAULT', - 'DEFERRABLE', - 'DESC', - 'DISTINCT', - 'DO', - 'ELSE', - 'END', - 'EXCEPT', - 'FALSE', - 'FOR', - 'FOREIGN', - 'FREEZE', - 'FROM', - 'FULL', - 'GRANT', - 'GROUP', - 'HAVING', - 'ILIKE', - 'IN', - 'INITIALLY', - 'INNER', - 'INTERSECT', - 'INTO', - 'IS', - 'ISNULL', - 'JOIN', - 'LEADING', - 'LEFT', - 'LIKE', - 'LIMIT', - 'LOCALTIME', - 'LOCALTIMESTAMP', - 'NATURAL', - 'NEW', - 'NOT', - 'NOTNULL', - 'NULL', - 'OFF', - 'OFFSET', - 'OLD', - 'ON', - 'ONLY', - 'OR', - 'ORDER', - 'OUTER', - 'OVERLAPS', - 'PLACING', - 'PRIMARY', - 'REFERENCES', - 'SELECT', - 'SESSION_USER', - 'SIMILAR', - 'SOME', - 'TABLE', - 'THEN', - 'TO', - 'TRAILING', - 'TRUE', - 'UNION', - 'UNIQUE', - 'USER', - 'USING', - 'VERBOSE', - 'WHEN', - 'WHERE' + "ALL", + "ANALYSE", + "ANALYZE", + "AND", + "ANY", + "AS", + "ASC", + "AUTHORIZATION", + "BETWEEN", + "BINARY", + "BOTH", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "CONSTRAINT", + "CREATE", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DESC", + "DISTINCT", + "DO", + "ELSE", + "END", + "EXCEPT", + "FALSE", + "FOR", + "FOREIGN", + "FREEZE", + "FROM", + "FULL", + "GRANT", + "GROUP", + "HAVING", + "ILIKE", + "IN", + "INITIALLY", + "INNER", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + 
"LEADING", + "LEFT", + "LIKE", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NATURAL", + "NEW", + "NOT", + "NOTNULL", + "NULL", + "OFF", + "OFFSET", + "OLD", + "ON", + "ONLY", + "OR", + "ORDER", + "OUTER", + "OVERLAPS", + "PLACING", + "PRIMARY", + "REFERENCES", + "SELECT", + "SESSION_USER", + "SIMILAR", + "SOME", + "TABLE", + "THEN", + "TO", + "TRAILING", + "TRUE", + "UNION", + "UNIQUE", + "USER", + "USING", + "VERBOSE", + "WHEN", + "WHERE", ] def get_name(self): - return 'PostgreSQL' + return "PostgreSQL" diff --git a/orator/dbal/platforms/keywords/sqlite_keywords.py b/orator/dbal/platforms/keywords/sqlite_keywords.py index bb62f587..390896c0 100644 --- a/orator/dbal/platforms/keywords/sqlite_keywords.py +++ b/orator/dbal/platforms/keywords/sqlite_keywords.py @@ -6,128 +6,128 @@ class SQLiteKeywords(KeywordList): KEYWORDS = [ - 'ABORT', - 'ACTION', - 'ADD', - 'AFTER', - 'ALL', - 'ALTER', - 'ANALYZE', - 'AND', - 'AS', - 'ASC', - 'ATTACH', - 'AUTOINCREMENT', - 'BEFORE', - 'BEGIN', - 'BETWEEN', - 'BY', - 'CASCADE', - 'CASE', - 'CAST', - 'CHECK', - 'COLLATE', - 'COLUMN', - 'COMMIT', - 'CONFLICT', - 'CONSTRAINT', - 'CREATE', - 'CROSS', - 'CURRENT_DATE', - 'CURRENT_TIME', - 'CURRENT_TIMESTAMP', - 'DATABASE', - 'DEFAULT', - 'DEFERRABLE', - 'DEFERRED', - 'DELETE', - 'DESC', - 'DETACH', - 'DISTINCT', - 'DROP', - 'EACH', - 'ELSE', - 'END', - 'ESCAPE', - 'EXCEPT', - 'EXCLUSIVE', - 'EXISTS', - 'EXPLAIN', - 'FAIL', - 'FOR', - 'FOREIGN', - 'FROM', - 'FULL', - 'GLOB', - 'GROUP', - 'HAVING', - 'IF', - 'IGNORE', - 'IMMEDIATE', - 'IN', - 'INDEX', - 'INDEXED', - 'INITIALLY', - 'INNER', - 'INSERT', - 'INSTEAD', - 'INTERSECT', - 'INTO', - 'IS', - 'ISNULL', - 'JOIN', - 'KEY', - 'LEFT', - 'LIKE', - 'LIMIT', - 'MATCH', - 'NATURAL', - 'NO', - 'NOT', - 'NOTNULL', - 'NULL', - 'OF', - 'OFFSET', - 'ON', - 'OR', - 'ORDER', - 'OUTER', - 'PLAN', - 'PRAGMA', - 'PRIMARY', - 'QUERY', - 'RAISE', - 'REFERENCES', - 'REGEXP', - 'REINDEX', - 'RELEASE', - 'RENAME', - 'REPLACE', - 'RESTRICT', - 
'RIGHT', - 'ROLLBACK', - 'ROW', - 'SAVEPOINT', - 'SELECT', - 'SET', - 'TABLE', - 'TEMP', - 'TEMPORARY', - 'THEN', - 'TO', - 'TRANSACTION', - 'TRIGGER', - 'UNION', - 'UNIQUE', - 'UPDATE', - 'USING', - 'VACUUM', - 'VALUES', - 'VIEW', - 'VIRTUAL', - 'WHEN', - 'WHERE' + "ABORT", + "ACTION", + "ADD", + "AFTER", + "ALL", + "ALTER", + "ANALYZE", + "AND", + "AS", + "ASC", + "ATTACH", + "AUTOINCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BY", + "CASCADE", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "COMMIT", + "CONFLICT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATABASE", + "DEFAULT", + "DEFERRABLE", + "DEFERRED", + "DELETE", + "DESC", + "DETACH", + "DISTINCT", + "DROP", + "EACH", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXCLUSIVE", + "EXISTS", + "EXPLAIN", + "FAIL", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GLOB", + "GROUP", + "HAVING", + "IF", + "IGNORE", + "IMMEDIATE", + "IN", + "INDEX", + "INDEXED", + "INITIALLY", + "INNER", + "INSERT", + "INSTEAD", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "KEY", + "LEFT", + "LIKE", + "LIMIT", + "MATCH", + "NATURAL", + "NO", + "NOT", + "NOTNULL", + "NULL", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER", + "OUTER", + "PLAN", + "PRAGMA", + "PRIMARY", + "QUERY", + "RAISE", + "REFERENCES", + "REGEXP", + "REINDEX", + "RELEASE", + "RENAME", + "REPLACE", + "RESTRICT", + "RIGHT", + "ROLLBACK", + "ROW", + "SAVEPOINT", + "SELECT", + "SET", + "TABLE", + "TEMP", + "TEMPORARY", + "THEN", + "TO", + "TRANSACTION", + "TRIGGER", + "UNION", + "UNIQUE", + "UPDATE", + "USING", + "VACUUM", + "VALUES", + "VIEW", + "VIRTUAL", + "WHEN", + "WHERE", ] def get_name(self): - return 'SQLite' + return "SQLite" diff --git a/orator/dbal/platforms/mysql57_platform.py b/orator/dbal/platforms/mysql57_platform.py index 3e0d8a10..f10638fd 100644 --- a/orator/dbal/platforms/mysql57_platform.py +++ b/orator/dbal/platforms/mysql57_platform.py @@ -6,45 +6,45 @@ class 
MySQL57Platform(MySQLPlatform): INTERNAL_TYPE_MAPPING = { - 'tinyint': 'boolean', - 'smallint': 'smallint', - 'mediumint': 'integer', - 'int': 'integer', - 'integer': 'integer', - 'bigint': 'bigint', - 'int8': 'bigint', - 'bool': 'boolean', - 'boolean': 'boolean', - 'tinytext': 'text', - 'mediumtext': 'text', - 'longtext': 'text', - 'text': 'text', - 'varchar': 'string', - 'string': 'string', - 'char': 'string', - 'date': 'date', - 'datetime': 'datetime', - 'timestamp': 'datetime', - 'time': 'time', - 'float': 'float', - 'double': 'float', - 'real': 'float', - 'decimal': 'decimal', - 'numeric': 'decimal', - 'year': 'date', - 'longblob': 'blob', - 'blob': 'blob', - 'mediumblob': 'blob', - 'tinyblob': 'blob', - 'binary': 'binary', - 'varbinary': 'binary', - 'set': 'simple_array', - 'enum': 'enum', - 'json': 'json', + "tinyint": "boolean", + "smallint": "smallint", + "mediumint": "integer", + "int": "integer", + "integer": "integer", + "bigint": "bigint", + "int8": "bigint", + "bool": "boolean", + "boolean": "boolean", + "tinytext": "text", + "mediumtext": "text", + "longtext": "text", + "text": "text", + "varchar": "string", + "string": "string", + "char": "string", + "date": "date", + "datetime": "datetime", + "timestamp": "datetime", + "time": "time", + "float": "float", + "double": "float", + "real": "float", + "decimal": "decimal", + "numeric": "decimal", + "year": "date", + "longblob": "blob", + "blob": "blob", + "mediumblob": "blob", + "tinyblob": "blob", + "binary": "binary", + "varbinary": "binary", + "set": "simple_array", + "enum": "enum", + "json": "json", } def get_json_type_declaration_sql(self, column): - return 'JSON' + return "JSON" def has_native_json_type(self): return True diff --git a/orator/dbal/platforms/mysql_platform.py b/orator/dbal/platforms/mysql_platform.py index 8de03921..75a763ff 100644 --- a/orator/dbal/platforms/mysql_platform.py +++ b/orator/dbal/platforms/mysql_platform.py @@ -16,53 +16,55 @@ class MySQLPlatform(Platform): 
LENGTH_LIMIT_MEDIUMBLOB = 16777215 INTERNAL_TYPE_MAPPING = { - 'tinyint': 'boolean', - 'smallint': 'smallint', - 'mediumint': 'integer', - 'int': 'integer', - 'integer': 'integer', - 'bigint': 'bigint', - 'int8': 'bigint', - 'bool': 'boolean', - 'boolean': 'boolean', - 'tinytext': 'text', - 'mediumtext': 'text', - 'longtext': 'text', - 'text': 'text', - 'varchar': 'string', - 'string': 'string', - 'char': 'string', - 'date': 'date', - 'datetime': 'datetime', - 'timestamp': 'datetime', - 'time': 'time', - 'float': 'float', - 'double': 'float', - 'real': 'float', - 'decimal': 'decimal', - 'numeric': 'decimal', - 'year': 'date', - 'longblob': 'blob', - 'blob': 'blob', - 'mediumblob': 'blob', - 'tinyblob': 'blob', - 'binary': 'binary', - 'varbinary': 'binary', - 'set': 'simple_array', - 'enum': 'enum', + "tinyint": "boolean", + "smallint": "smallint", + "mediumint": "integer", + "int": "integer", + "integer": "integer", + "bigint": "bigint", + "int8": "bigint", + "bool": "boolean", + "boolean": "boolean", + "tinytext": "text", + "mediumtext": "text", + "longtext": "text", + "text": "text", + "varchar": "string", + "string": "string", + "char": "string", + "date": "date", + "datetime": "datetime", + "timestamp": "datetime", + "time": "time", + "float": "float", + "double": "float", + "real": "float", + "decimal": "decimal", + "numeric": "decimal", + "year": "date", + "longblob": "blob", + "blob": "blob", + "mediumblob": "blob", + "tinyblob": "blob", + "binary": "binary", + "varbinary": "binary", + "set": "simple_array", + "enum": "enum", } def get_list_table_columns_sql(self, table, database=None): if database: database = "'%s'" % database else: - database = 'DATABASE()' + database = "DATABASE()" - return 'SELECT COLUMN_NAME AS field, COLUMN_TYPE AS type, IS_NULLABLE AS `null`, ' \ - 'COLUMN_KEY AS `key`, COLUMN_DEFAULT AS `default`, EXTRA AS extra, COLUMN_COMMENT AS comment, ' \ - 'CHARACTER_SET_NAME AS character_set, COLLATION_NAME AS collation ' \ - 'FROM 
information_schema.COLUMNS WHERE TABLE_SCHEMA = %s AND TABLE_NAME = \'%s\''\ - % (database, table) + return ( + "SELECT COLUMN_NAME AS field, COLUMN_TYPE AS type, IS_NULLABLE AS `null`, " + "COLUMN_KEY AS `key`, COLUMN_DEFAULT AS `default`, EXTRA AS extra, COLUMN_COMMENT AS comment, " + "CHARACTER_SET_NAME AS character_set, COLLATION_NAME AS collation " + "FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = %s AND TABLE_NAME = '%s'" + % (database, table) + ) def get_list_table_indexes_sql(self, table, current_database=None): sql = """ @@ -74,21 +76,25 @@ def get_list_table_indexes_sql(self, table, current_database=None): """ if current_database: - sql += ' AND TABLE_SCHEMA = \'%s\'' % current_database + sql += " AND TABLE_SCHEMA = '%s'" % current_database return sql % table def get_list_table_foreign_keys_sql(self, table, database=None): - sql = ("SELECT DISTINCT k.`CONSTRAINT_NAME`, k.`COLUMN_NAME`, k.`REFERENCED_TABLE_NAME`, " - "k.`REFERENCED_COLUMN_NAME` /*!50116 , c.update_rule, c.delete_rule */ " - "FROM information_schema.key_column_usage k /*!50116 " - "INNER JOIN information_schema.referential_constraints c ON " - " c.constraint_name = k.constraint_name AND " - " c.table_name = '%s' */ WHERE k.table_name = '%s'" % (table, table)) + sql = ( + "SELECT DISTINCT k.`CONSTRAINT_NAME`, k.`COLUMN_NAME`, k.`REFERENCED_TABLE_NAME`, " + "k.`REFERENCED_COLUMN_NAME` /*!50116 , c.update_rule, c.delete_rule */ " + "FROM information_schema.key_column_usage k /*!50116 " + "INNER JOIN information_schema.referential_constraints c ON " + " c.constraint_name = k.constraint_name AND " + " c.table_name = '%s' */ WHERE k.table_name = '%s'" % (table, table) + ) if database: - sql += " AND k.table_schema = '%s' /*!50116 AND c.constraint_schema = '%s' */"\ - % (database, database) + sql += ( + " AND k.table_schema = '%s' /*!50116 AND c.constraint_schema = '%s' */" + % (database, database) + ) sql += " AND k.`REFERENCED_COLUMN_NAME` IS NOT NULL" @@ -107,7 +113,9 @@ def 
get_alter_table_sql(self, diff): query_parts = [] if diff.new_name is not False: - query_parts.append('RENAME TO %s' % diff.get_new_name().get_quoted_name(self)) + query_parts.append( + "RENAME TO %s" % diff.get_new_name().get_quoted_name(self) + ) # Added columns? @@ -118,30 +126,43 @@ def get_alter_table_sql(self, diff): column_dict = column.to_dict() # Don't propagate default value changes for unsupported column types. - if column_diff.has_changed('default') \ - and len(column_diff.changed_properties) == 1 \ - and (column_dict['type'] == 'text' or column_dict['type'] == 'blob'): + if ( + column_diff.has_changed("default") + and len(column_diff.changed_properties) == 1 + and (column_dict["type"] == "text" or column_dict["type"] == "blob") + ): continue - query_parts.append('CHANGE %s %s' - % (column_diff.get_old_column_name().get_quoted_name(self), - self.get_column_declaration_sql(column.get_quoted_name(self), column_dict))) + query_parts.append( + "CHANGE %s %s" + % ( + column_diff.get_old_column_name().get_quoted_name(self), + self.get_column_declaration_sql( + column.get_quoted_name(self), column_dict + ), + ) + ) for old_column_name, column in diff.renamed_columns.items(): column_dict = column.to_dict() old_column_name = Identifier(old_column_name) - query_parts.append('CHANGE %s %s' - % (self.quote(old_column_name.get_quoted_name(self)), - self.get_column_declaration_sql( - self.quote(column.get_quoted_name(self)), - column_dict))) + query_parts.append( + "CHANGE %s %s" + % ( + self.quote(old_column_name.get_quoted_name(self)), + self.get_column_declaration_sql( + self.quote(column.get_quoted_name(self)), column_dict + ), + ) + ) sql = [] if len(query_parts) > 0: - sql.append('ALTER TABLE %s %s' - % (diff.get_name(self).get_quoted_name(self), - ', '.join(query_parts))) + sql.append( + "ALTER TABLE %s %s" + % (diff.get_name(self).get_quoted_name(self), ", ".join(query_parts)) + ) return sql @@ -156,85 +177,85 @@ def convert_booleans(self, item): return item 
def get_boolean_type_declaration_sql(self, column): - return 'TINYINT(1)' + return "TINYINT(1)" def get_integer_type_declaration_sql(self, column): - return 'INT ' + self._get_common_integer_type_declaration_sql(column) + return "INT " + self._get_common_integer_type_declaration_sql(column) def get_bigint_type_declaration_sql(self, column): - return 'BIGINT ' + self._get_common_integer_type_declaration_sql(column) + return "BIGINT " + self._get_common_integer_type_declaration_sql(column) def get_smallint_type_declaration_sql(self, column): - return 'SMALLINT ' + self._get_common_integer_type_declaration_sql(column) + return "SMALLINT " + self._get_common_integer_type_declaration_sql(column) def get_guid_type_declaration_sql(self, column): - return 'UUID' + return "UUID" def get_datetime_type_declaration_sql(self, column): - if 'version' in column and column['version'] == True: - return 'TIMESTAMP' + if "version" in column and column["version"] == True: + return "TIMESTAMP" - return 'DATETIME' + return "DATETIME" def get_date_type_declaration_sql(self, column): - return 'DATE' + return "DATE" def get_time_type_declaration_sql(self, column): - return 'TIME' + return "TIME" def get_varchar_type_declaration_sql_snippet(self, length, fixed): if fixed: - return 'CHAR(%s)' % length if length else 'CHAR(255)' + return "CHAR(%s)" % length if length else "CHAR(255)" else: - return 'VARCHAR(%s)' % length if length else 'VARCHAR(255)' + return "VARCHAR(%s)" % length if length else "VARCHAR(255)" def get_binary_type_declaration_sql_snippet(self, length, fixed): if fixed: - return 'BINARY(%s)' % (length or 255) + return "BINARY(%s)" % (length or 255) else: - return 'VARBINARY(%s)' % (length or 255) + return "VARBINARY(%s)" % (length or 255) def get_text_type_declaration_sql(self, column): - length = column.get('length') + length = column.get("length") if length: if length <= self.LENGTH_LIMIT_TINYTEXT: - return 'TINYTEXT' + return "TINYTEXT" if length <= self.LENGTH_LIMIT_TEXT: 
- return 'TEXT' + return "TEXT" if length <= self.LENGTH_LIMIT_MEDIUMTEXT: - return 'MEDIUMTEXT' + return "MEDIUMTEXT" - return 'LONGTEXT' + return "LONGTEXT" def get_blob_type_declaration_sql(self, column): - length = column.get('length') + length = column.get("length") if length: if length <= self.LENGTH_LIMIT_TINYBLOB: - return 'TINYBLOB' + return "TINYBLOB" if length <= self.LENGTH_LIMIT_BLOB: - return 'BLOB' + return "BLOB" if length <= self.LENGTH_LIMIT_MEDIUMBLOB: - return 'MEDIUMBLOB' + return "MEDIUMBLOB" - return 'LONGBLOB' + return "LONGBLOB" def get_clob_type_declaration_sql(self, column): - length = column.get('length') + length = column.get("length") if length: if length <= self.LENGTH_LIMIT_TINYTEXT: - return 'TINYTEXT' + return "TINYTEXT" if length <= self.LENGTH_LIMIT_TEXT: - return 'TEXT' + return "TEXT" if length <= self.LENGTH_LIMIT_MEDIUMTEXT: - return 'MEDIUMTEXT' + return "MEDIUMTEXT" - return 'LONGTEXT' + return "LONGTEXT" def get_decimal_type_declaration_sql(self, column): decl = super(MySQLPlatform, self).get_decimal_type_declaration_sql(column) @@ -242,23 +263,23 @@ def get_decimal_type_declaration_sql(self, column): return decl + self.get_unsigned_declaration(column) def get_unsigned_declaration(self, column): - if column.get('unsigned'): - return ' UNSIGNED' + if column.get("unsigned"): + return " UNSIGNED" - return '' + return "" def _get_common_integer_type_declaration_sql(self, column): - autoinc = '' - if column.get('autoincrement'): - autoinc = ' AUTO_INCREMENT' + autoinc = "" + if column.get("autoincrement"): + autoinc = " AUTO_INCREMENT" return self.get_unsigned_declaration(column) + autoinc def get_float_type_declaration_sql(self, column): - return 'DOUBLE PRECISION' + self.get_unsigned_declaration(column) + return "DOUBLE PRECISION" + self.get_unsigned_declaration(column) def get_enum_type_declaration_sql(self, column): - return 'ENUM{}'.format(column['extra']['definition']) + return 
"ENUM{}".format(column["extra"]["definition"]) def supports_foreign_key_constraints(self): return True @@ -267,10 +288,10 @@ def supports_column_collation(self): return False def quote(self, name): - return '`%s`' % name.replace('`', '``') + return "`%s`" % name.replace("`", "``") def _get_reserved_keywords_class(self): return MySQLKeywords def get_identifier_quote_character(self): - return '`' + return "`" diff --git a/orator/dbal/platforms/platform.py b/orator/dbal/platforms/platform.py index 438c13a4..9814a69d 100644 --- a/orator/dbal/platforms/platform.py +++ b/orator/dbal/platforms/platform.py @@ -22,30 +22,39 @@ def __init__(self, version=None): self._version = None def get_default_value_declaration_sql(self, field): - default = '' - - if not field.get('notnull'): - default = ' DEFAULT NULL' - - if 'default' in field and field['default'] is not None: - default = ' DEFAULT \'%s\'' % field['default'] - - if 'type' in field: - type = field['type'] - - if type in ['integer', 'bigint', 'smallint']: - default = ' DEFAULT %s' % field['default'] - elif type in ['datetime', 'datetimetz'] \ - and field['default'] in [self.get_current_timestamp_sql(), 'NOW', 'now']: - default = ' DEFAULT %s' % self.get_current_timestamp_sql() - elif type in ['time'] \ - and field['default'] in [self.get_current_time_sql(), 'NOW', 'now']: - default = ' DEFAULT %s' % self.get_current_time_sql() - elif type in ['date'] \ - and field['default'] in [self.get_current_date_sql(), 'NOW', 'now']: - default = ' DEFAULT %s' % self.get_current_date_sql() - elif type in ['boolean']: - default = ' DEFAULT \'%s\'' % self.convert_booleans(field['default']) + default = "" + + if not field.get("notnull"): + default = " DEFAULT NULL" + + if "default" in field and field["default"] is not None: + default = " DEFAULT '%s'" % field["default"] + + if "type" in field: + type = field["type"] + + if type in ["integer", "bigint", "smallint"]: + default = " DEFAULT %s" % field["default"] + elif type in ["datetime", 
"datetimetz"] and field["default"] in [ + self.get_current_timestamp_sql(), + "NOW", + "now", + ]: + default = " DEFAULT %s" % self.get_current_timestamp_sql() + elif type in ["time"] and field["default"] in [ + self.get_current_time_sql(), + "NOW", + "now", + ]: + default = " DEFAULT %s" % self.get_current_time_sql() + elif type in ["date"] and field["default"] in [ + self.get_current_date_sql(), + "NOW", + "now", + ]: + default = " DEFAULT %s" % self.get_current_date_sql() + elif type in ["boolean"]: + default = " DEFAULT '%s'" % self.convert_booleans(field["default"]) return default @@ -73,15 +82,15 @@ def get_check_declaration_sql(self, definition): constraints = [] for field, def_ in definition.items(): if isinstance(def_, basestring): - constraints.append('CHECK (%s)' % def_) + constraints.append("CHECK (%s)" % def_) else: - if 'min' in def_: - constraints.append('CHECK (%s >= %s)' % (field, def_['min'])) + if "min" in def_: + constraints.append("CHECK (%s >= %s)" % (field, def_["min"])) - if 'max' in def_: - constraints.append('CHECK (%s <= %s)' % (field, def_['max'])) + if "max" in def_: + constraints.append("CHECK (%s <= %s)" % (field, def_["max"])) - return ', '.join(constraints) + return ", ".join(constraints) def get_unique_constraint_declaration_sql(self, name, index): """ @@ -103,10 +112,11 @@ def get_unique_constraint_declaration_sql(self, name, index): if not columns: raise DBALException('Incomplete definition. "columns" required.') - return 'CONSTRAINT %s UNIQUE (%s)%s'\ - % (name.get_quoted_name(self), - self.get_index_field_declaration_list_sql(columns), - self.get_partial_index_sql(index)) + return "CONSTRAINT %s UNIQUE (%s)%s" % ( + name.get_quoted_name(self), + self.get_index_field_declaration_list_sql(columns), + self.get_partial_index_sql(index), + ) def get_index_declaration_sql(self, name, index): """ @@ -128,11 +138,12 @@ def get_index_declaration_sql(self, name, index): if not columns: raise DBALException('Incomplete definition. 
"columns" required.') - return '%sINDEX %s (%s)%s'\ - % (self.get_create_index_sql_flags(index), - name.get_quoted_name(self), - self.get_index_field_declaration_list_sql(columns), - self.get_partial_index_sql(index)) + return "%sINDEX %s (%s)%s" % ( + self.get_create_index_sql_flags(index), + name.get_quoted_name(self), + self.get_index_field_declaration_list_sql(columns), + self.get_partial_index_sql(index), + ) def get_foreign_key_declaration_sql(self, foreign_key): """ @@ -159,12 +170,18 @@ def get_advanced_foreign_key_options_sql(self, foreign_key): :rtype: str """ - query = '' - if self.supports_foreign_key_on_update() and foreign_key.has_option('on_update'): - query += ' ON UPDATE %s' % self.get_foreign_key_referential_action_sql(foreign_key.get_option('on_update')) + query = "" + if self.supports_foreign_key_on_update() and foreign_key.has_option( + "on_update" + ): + query += " ON UPDATE %s" % self.get_foreign_key_referential_action_sql( + foreign_key.get_option("on_update") + ) - if foreign_key.has_option('on_delete'): - query += ' ON DELETE %s' % self.get_foreign_key_referential_action_sql(foreign_key.get_option('on_delete')) + if foreign_key.has_option("on_delete"): + query += " ON DELETE %s" % self.get_foreign_key_referential_action_sql( + foreign_key.get_option("on_delete") + ) return query @@ -178,8 +195,14 @@ def get_foreign_key_referential_action_sql(self, action): :rtype: str """ action = action.upper() - if action not in ['CASCADE', 'SET NULL', 'NO ACTION', 'RESTRICT', 'SET DEFAULT']: - raise DBALException('Invalid foreign key action: %s' % action) + if action not in [ + "CASCADE", + "SET NULL", + "NO ACTION", + "RESTRICT", + "SET DEFAULT", + ]: + raise DBALException("Invalid foreign key action: %s" % action) return action @@ -193,11 +216,11 @@ def get_foreign_key_base_declaration_sql(self, foreign_key): :rtype: str """ - sql = '' + sql = "" if foreign_key.get_name(): - sql += 'CONSTRAINT %s ' % foreign_key.get_quoted_name(self) + sql += 
"CONSTRAINT %s " % foreign_key.get_quoted_name(self) - sql += 'FOREIGN KEY (' + sql += "FOREIGN KEY (" if not foreign_key.get_local_columns(): raise DBALException('Incomplete definition. "local" required.') @@ -208,26 +231,27 @@ def get_foreign_key_base_declaration_sql(self, foreign_key): if not foreign_key.get_foreign_table_name(): raise DBALException('Incomplete definition. "foreign_table" required.') - sql += '%s) REFERENCES %s (%s)'\ - % (', '.join(foreign_key.get_quoted_local_columns(self)), - foreign_key.get_quoted_foreign_table_name(self), - ', '.join(foreign_key.get_quoted_foreign_columns(self))) + sql += "%s) REFERENCES %s (%s)" % ( + ", ".join(foreign_key.get_quoted_local_columns(self)), + foreign_key.get_quoted_foreign_table_name(self), + ", ".join(foreign_key.get_quoted_foreign_columns(self)), + ) return sql def get_current_date_sql(self): - return 'CURRENT_DATE' + return "CURRENT_DATE" def get_current_time_sql(self): - return 'CURRENT_TIME' + return "CURRENT_TIME" def get_current_timestamp_sql(self): - return 'CURRENT_TIMESTAMP' + return "CURRENT_TIMESTAMP" def get_sql_type_declaration(self, column): - internal_type = column['type'] + internal_type = column["type"] - return getattr(self, 'get_%s_type_declaration_sql' % internal_type)(column) + return getattr(self, "get_%s_type_declaration_sql" % internal_type)(column) def get_column_declaration_list_sql(self, fields): """ @@ -238,95 +262,97 @@ def get_column_declaration_list_sql(self, fields): for name, field in fields.items(): query_fields.append(self.get_column_declaration_sql(name, field)) - return ', '.join(query_fields) + return ", ".join(query_fields) def get_column_declaration_sql(self, name, field): - if 'column_definition' in field: + if "column_definition" in field: column_def = self.get_custom_type_declaration_sql(field) else: default = self.get_default_value_declaration_sql(field) - charset = field.get('charset', '') + charset = field.get("charset", "") if charset: - charset = ' ' + 
self.get_column_charset_declaration_sql(charset) + charset = " " + self.get_column_charset_declaration_sql(charset) - collation = field.get('collation', '') + collation = field.get("collation", "") if charset: - charset = ' ' + self.get_column_collation_declaration_sql(charset) + charset = " " + self.get_column_collation_declaration_sql(charset) - notnull = field.get('notnull', '') + notnull = field.get("notnull", "") if notnull: - notnull = ' NOT NULL' + notnull = " NOT NULL" else: - notnull = '' + notnull = "" - unique = field.get('unique', '') + unique = field.get("unique", "") if unique: - unique = ' ' + self.get_unique_field_declaration_sql() + unique = " " + self.get_unique_field_declaration_sql() else: - unique = '' + unique = "" - check = field.get('check', '') + check = field.get("check", "") type_decl = self.get_sql_type_declaration(field) - column_def = type_decl + charset + default + notnull + unique + check + collation + column_def = ( + type_decl + charset + default + notnull + unique + check + collation + ) - return name + ' ' + column_def + return name + " " + column_def def get_custom_type_declaration_sql(self, column_def): - return column_def['column_definition'] + return column_def["column_definition"] def get_column_charset_declaration_sql(self, charset): - return '' + return "" def get_column_collation_declaration_sql(self, collation): if self.supports_column_collation(): - return 'COLLATE %s' % collation + return "COLLATE %s" % collation - return '' + return "" def supports_column_collation(self): return False def get_unique_field_declaration_sql(self): - return 'UNIQUE' + return "UNIQUE" def get_string_type_declaration_sql(self, column): - if 'length' not in column: - column['length'] = self.get_varchar_default_length() + if "length" not in column: + column["length"] = self.get_varchar_default_length() - fixed = column.get('fixed', False) + fixed = column.get("fixed", False) - if column['length'] > self.get_varchar_max_length(): + if 
column["length"] > self.get_varchar_max_length(): return self.get_clob_type_declaration_sql(column) - return self.get_varchar_type_declaration_sql_snippet(column['length'], fixed) + return self.get_varchar_type_declaration_sql_snippet(column["length"], fixed) def get_binary_type_declaration_sql(self, column): - if 'length' not in column: - column['length'] = self.get_binary_default_length() + if "length" not in column: + column["length"] = self.get_binary_default_length() - fixed = column.get('fixed', False) + fixed = column.get("fixed", False) - if column['length'] > self.get_binary_max_length(): + if column["length"] > self.get_binary_max_length(): return self.get_blob_type_declaration_sql(column) - return self.get_binary_type_declaration_sql_snippet(column['length'], fixed) + return self.get_binary_type_declaration_sql_snippet(column["length"], fixed) def get_varchar_type_declaration_sql_snippet(self, length, fixed): - raise NotImplementedError('VARCHARS not supported by Platform') + raise NotImplementedError("VARCHARS not supported by Platform") def get_binary_type_declaration_sql_snippet(self, length, fixed): - raise NotImplementedError('BINARY/VARBINARY not supported by Platform') + raise NotImplementedError("BINARY/VARBINARY not supported by Platform") def get_decimal_type_declaration_sql(self, column): - if 'precision' not in column or not column['precision']: - column['precision'] = 10 + if "precision" not in column or not column["precision"]: + column["precision"] = 10 - if 'scale' not in column or not column['scale']: - column['precision'] = 0 + if "scale" not in column or not column["scale"]: + column["precision"] = 0 - return 'NUMERIC(%s, %s)' % (column['precision'], column['scale']) + return "NUMERIC(%s, %s)" % (column["precision"], column["scale"]) def get_json_type_declaration_sql(self, column): return self.get_clob_type_declaration_sql(column) @@ -387,7 +413,7 @@ def get_index_field_declaration_list_sql(self, fields): for field in fields: 
ret.append(field) - return ', '.join(ret) + return ", ".join(ret) def get_create_index_sql(self, index, table): """ @@ -413,9 +439,15 @@ def get_create_index_sql(self, index, table): if index.is_primary(): return self.get_create_primary_key_sql(index, table) - query = 'CREATE %sINDEX %s ON %s' % (self.get_create_index_sql_flags(index), name, table) - query += ' (%s)%s' % (self.get_index_field_declaration_list_sql(columns), - self.get_partial_index_sql(index)) + query = "CREATE %sINDEX %s ON %s" % ( + self.get_create_index_sql_flags(index), + name, + table, + ) + query += " (%s)%s" % ( + self.get_index_field_declaration_list_sql(columns), + self.get_partial_index_sql(index), + ) return query @@ -428,10 +460,10 @@ def get_partial_index_sql(self, index): :rtype: str """ - if self.supports_partial_indexes() and index.has_option('where'): - return ' WHERE %s' % index.get_option('where') + if self.supports_partial_indexes() and index.has_option("where"): + return " WHERE %s" % index.get_option("where") - return '' + return "" def get_create_index_sql_flags(self, index): """ @@ -443,9 +475,9 @@ def get_create_index_sql_flags(self, index): :rtype: str """ if index.is_unique(): - return 'UNIQUE ' + return "UNIQUE " - return '' + return "" def get_create_primary_key_sql(self, index, table): """ @@ -459,9 +491,10 @@ def get_create_primary_key_sql(self, index, table): :rtype: str """ - return 'ALTER TABLE %s ADD PRIMARY KEY (%s)'\ - % (table, - self.get_index_field_declaration_list_sql(index.get_quoted_columns(self))) + return "ALTER TABLE %s ADD PRIMARY KEY (%s)" % ( + table, + self.get_index_field_declaration_list_sql(index.get_quoted_columns(self)), + ) def get_create_foreign_key_sql(self, foreign_key, table): """ @@ -472,7 +505,10 @@ def get_create_foreign_key_sql(self, foreign_key, table): if isinstance(table, Table): table = table.get_quoted_name(self) - query = 'ALTER TABLE %s ADD %s' % (table, self.get_foreign_key_declaration_sql(foreign_key)) + query = "ALTER TABLE %s 
ADD %s" % ( + table, + self.get_foreign_key_declaration_sql(foreign_key), + ) return query @@ -488,7 +524,7 @@ def get_drop_table_sql(self, table): if isinstance(table, Table): table = table.get_quoted_name(self) - return 'DROP TABLE %s' % table + return "DROP TABLE %s" % table def get_drop_index_sql(self, index, table=None): """ @@ -505,7 +541,7 @@ def get_drop_index_sql(self, index, table=None): if isinstance(index, Index): index = index.get_quoted_name(self) - return 'DROP INDEX %s' % index + return "DROP INDEX %s" % index def get_create_table_sql(self, table, create_flags=CREATE_INDEXES): """ @@ -523,42 +559,42 @@ def get_create_table_sql(self, table, create_flags=CREATE_INDEXES): table_name = table.get_quoted_name(self) options = dict((k, v) for k, v in table.get_options().items()) - options['unique_constraints'] = OrderedDict() - options['indexes'] = OrderedDict() - options['primary'] = [] + options["unique_constraints"] = OrderedDict() + options["indexes"] = OrderedDict() + options["primary"] = [] if create_flags & self.CREATE_INDEXES > 0: for index in table.get_indexes().values(): if index.is_primary(): - options['primary'] = index.get_quoted_columns(self) - options['primary_index'] = index + options["primary"] = index.get_quoted_columns(self) + options["primary_index"] = index else: - options['indexes'][index.get_quoted_name(self)] = index + options["indexes"][index.get_quoted_name(self)] = index columns = OrderedDict() for column in table.get_columns().values(): column_data = column.to_dict() - column_data['name'] = column.get_quoted_name(self) - if column.has_platform_option('version'): - column_data['version'] = column.get_platform_option('version') + column_data["name"] = column.get_quoted_name(self) + if column.has_platform_option("version"): + column_data["version"] = column.get_platform_option("version") else: - column_data['version'] = False + column_data["version"] = False # column_data['comment'] = self.get_column_comment(column) - if 
column_data['type'] == 'string' and column_data['length'] is None: - column_data['length'] = 255 + if column_data["type"] == "string" and column_data["length"] is None: + column_data["length"] = 255 - if column.get_name() in options['primary']: - column_data['primary'] = True + if column.get_name() in options["primary"]: + column_data["primary"] = True - columns[column_data['name']] = column_data + columns[column_data["name"]] = column_data if create_flags & self.CREATE_FOREIGNKEYS > 0: - options['foreign_keys'] = [] + options["foreign_keys"] = [] for fk in table.get_foreign_keys().values(): - options['foreign_keys'].append(fk) + options["foreign_keys"].append(fk) sql = self._get_create_table_sql(table_name, columns, options) @@ -585,34 +621,37 @@ def _get_create_table_sql(self, table_name, columns, options=None): column_list_sql = self.get_column_declaration_list_sql(columns) - if options.get('unique_constraints'): - for name, definition in options['unique_constraints'].items(): - column_list_sql += ', %s' % self.get_unique_constraint_declaration_sql(name, definition) + if options.get("unique_constraints"): + for name, definition in options["unique_constraints"].items(): + column_list_sql += ", %s" % self.get_unique_constraint_declaration_sql( + name, definition + ) - if options.get('primary'): - column_list_sql += ', PRIMARY KEY(%s)' % ', '.join(options['primary']) + if options.get("primary"): + column_list_sql += ", PRIMARY KEY(%s)" % ", ".join(options["primary"]) - if options.get('indexes'): - for index, definition in options['indexes']: - column_list_sql += ', %s' % self.get_index_declaration_sql(index, definition) + if options.get("indexes"): + for index, definition in options["indexes"]: + column_list_sql += ", %s" % self.get_index_declaration_sql( + index, definition + ) - query = 'CREATE TABLE %s (%s' % (table_name, column_list_sql) + query = "CREATE TABLE %s (%s" % (table_name, column_list_sql) check = self.get_check_declaration_sql(columns) if check: - 
query += ', %s' % check + query += ", %s" % check - query += ')' + query += ")" sql = [query] - if options.get('foreign_keys'): - for definition in options['foreign_keys']: + if options.get("foreign_keys"): + for definition in options["foreign_keys"]: sql.append(self.get_create_foreign_key_sql(definition, table_name)) return sql - def quote_identifier(self, string): """ Quotes a string so that it can be safely used as a table or column name, @@ -625,10 +664,10 @@ def quote_identifier(self, string): :return: The quoted identifier string. :rtype: str """ - if '.' in string: - parts = list(map(self.quote_single_identifier, string.split('.'))) + if "." in string: + parts = list(map(self.quote_single_identifier, string.split("."))) - return '.'.join(parts) + return ".".join(parts) return self.quote_single_identifier(string) @@ -644,7 +683,7 @@ def quote_single_identifier(self, string): """ c = self.get_identifier_quote_character() - return '%s%s%s' % (c, string.replace(c, c+c), c) + return "%s%s%s" % (c, string.replace(c, c + c), c) def get_identifier_quote_character(self): return '"' diff --git a/orator/dbal/platforms/postgres_platform.py b/orator/dbal/platforms/postgres_platform.py index 0901e27d..af4c6d60 100644 --- a/orator/dbal/platforms/postgres_platform.py +++ b/orator/dbal/platforms/postgres_platform.py @@ -10,46 +10,46 @@ class PostgresPlatform(Platform): INTERNAL_TYPE_MAPPING = { - 'smallint': 'smallint', - 'int2': 'smallint', - 'serial': 'integer', - 'serial4': 'integer', - 'int': 'integer', - 'int4': 'integer', - 'integer': 'integer', - 'bigserial': 'bigint', - 'serial8': 'bigint', - 'bigint': 'bigint', - 'int8': 'bigint', - 'bool': 'boolean', - 'boolean': 'boolean', - 'text': 'text', - 'tsvector': 'text', - 'varchar': 'string', - 'interval': 'string', - '_varchar': 'string', - 'char': 'string', - 'bpchar': 'string', - 'inet': 'string', - 'date': 'date', - 'datetime': 'datetime', - 'timestamp': 'datetime', - 'timestamptz': 'datetimez', - 'time': 'time', - 
'timetz': 'time', - 'float': 'float', - 'float4': 'float', - 'float8': 'float', - 'double': 'float', - 'double precision': 'float', - 'real': 'float', - 'decimal': 'decimal', - 'money': 'decimal', - 'numeric': 'decimal', - 'year': 'date', - 'uuid': 'guid', - 'bytea': 'blob', - 'json': 'json' + "smallint": "smallint", + "int2": "smallint", + "serial": "integer", + "serial4": "integer", + "int": "integer", + "int4": "integer", + "integer": "integer", + "bigserial": "bigint", + "serial8": "bigint", + "bigint": "bigint", + "int8": "bigint", + "bool": "boolean", + "boolean": "boolean", + "text": "text", + "tsvector": "text", + "varchar": "string", + "interval": "string", + "_varchar": "string", + "char": "string", + "bpchar": "string", + "inet": "string", + "date": "date", + "datetime": "datetime", + "timestamp": "datetime", + "timestamptz": "datetimez", + "time": "time", + "timetz": "time", + "float": "float", + "float4": "float", + "float8": "float", + "double": "float", + "double precision": "float", + "real": "float", + "decimal": "decimal", + "money": "decimal", + "numeric": "decimal", + "year": "date", + "uuid": "guid", + "bytea": "blob", + "json": "json", } def get_list_table_columns_sql(self, table): @@ -82,7 +82,9 @@ def get_list_table_columns_sql(self, table): AND a.attrelid = c.oid AND a.atttypid = t.oid AND n.oid = c.relnamespace - ORDER BY a.attnum""" % self.get_table_where_clause(table) + ORDER BY a.attnum""" % self.get_table_where_clause( + table + ) return sql @@ -99,57 +101,79 @@ def get_list_table_indexes_sql(self, table): AND sc.oid=si.indrelid AND sc.relnamespace = sn.oid ) AND pg_index.indexrelid = oid""" - sql = sql % self.get_table_where_clause(table, 'sc', 'sn') + sql = sql % self.get_table_where_clause(table, "sc", "sn") return sql def get_list_table_foreign_keys_sql(self, table): - return 'SELECT quote_ident(r.conname) as conname, ' \ - 'pg_catalog.pg_get_constraintdef(r.oid, true) AS condef ' \ - 'FROM pg_catalog.pg_constraint r ' \ - 'WHERE 
r.conrelid = ' \ - '(' \ - 'SELECT c.oid ' \ - 'FROM pg_catalog.pg_class c, pg_catalog.pg_namespace n ' \ - 'WHERE ' + self.get_table_where_clause(table) + ' AND n.oid = c.relnamespace' \ - ')' \ - ' AND r.contype = \'f\'' - - def get_table_where_clause(self, table, class_alias='c', namespace_alias='n'): - where_clause = namespace_alias + '.nspname NOT IN (\'pg_catalog\', \'information_schema\', \'pg_toast\') AND ' - if table.find('.') >= 0: - split = table.split('.') + return ( + "SELECT quote_ident(r.conname) as conname, " + "pg_catalog.pg_get_constraintdef(r.oid, true) AS condef " + "FROM pg_catalog.pg_constraint r " + "WHERE r.conrelid = " + "(" + "SELECT c.oid " + "FROM pg_catalog.pg_class c, pg_catalog.pg_namespace n " + "WHERE " + + self.get_table_where_clause(table) + + " AND n.oid = c.relnamespace" + ")" + " AND r.contype = 'f'" + ) + + def get_table_where_clause(self, table, class_alias="c", namespace_alias="n"): + where_clause = ( + namespace_alias + + ".nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') AND " + ) + if table.find(".") >= 0: + split = table.split(".") schema, table = split[0], split[1] schema = "'%s'" % schema else: - schema = 'ANY(string_to_array((select replace(replace(setting, \'"$user"\', user), \' \', \'\')' \ - ' from pg_catalog.pg_settings where name = \'search_path\'),\',\'))' - - where_clause += '%s.relname = \'%s\' AND %s.nspname = %s' % (class_alias, table, namespace_alias, schema) + schema = ( + "ANY(string_to_array((select replace(replace(setting, '\"$user\"', user), ' ', '')" + " from pg_catalog.pg_settings where name = 'search_path'),','))" + ) + + where_clause += "%s.relname = '%s' AND %s.nspname = %s" % ( + class_alias, + table, + namespace_alias, + schema, + ) return where_clause def get_advanced_foreign_key_options_sql(self, foreign_key): - query = '' + query = "" - if foreign_key.has_option('match'): - query += ' MATCH %s' % foreign_key.get_option('match') + if foreign_key.has_option("match"): + query += " 
MATCH %s" % foreign_key.get_option("match") - query += super(PostgresPlatform, self).get_advanced_foreign_key_options_sql(foreign_key) + query += super(PostgresPlatform, self).get_advanced_foreign_key_options_sql( + foreign_key + ) - deferrable = foreign_key.has_option('deferrable') and foreign_key.get_option('deferrable') is not False + deferrable = ( + foreign_key.has_option("deferrable") + and foreign_key.get_option("deferrable") is not False + ) if deferrable: - query += ' DEFERRABLE' + query += " DEFERRABLE" else: - query += ' NOT DEFERRABLE' + query += " NOT DEFERRABLE" - query += ' INITIALLY' + query += " INITIALLY" - deferred = foreign_key.has_option('deferred') and foreign_key.get_option('deferred') is not False + deferred = ( + foreign_key.has_option("deferred") + and foreign_key.get_option("deferred") is not False + ) if deferred: - query += ' DEFERRED' + query += " DEFERRED" else: - query += ' IMMEDIATE' + query += " IMMEDIATE" return query @@ -171,58 +195,118 @@ def get_alter_table_sql(self, diff): old_column_name = column_diff.get_old_column_name().get_quoted_name(self) column = column_diff.column - if any([column_diff.has_changed('type'), - column_diff.has_changed('precision'), - column_diff.has_changed('scale'), - column_diff.has_changed('fixed')]): - query = 'ALTER ' + old_column_name + ' TYPE ' + self.get_sql_type_declaration(column.to_dict()) - sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query) - - if column_diff.has_changed('default') or column_diff.has_changed('type'): + if any( + [ + column_diff.has_changed("type"), + column_diff.has_changed("precision"), + column_diff.has_changed("scale"), + column_diff.has_changed("fixed"), + ] + ): + query = ( + "ALTER " + + old_column_name + + " TYPE " + + self.get_sql_type_declaration(column.to_dict()) + ) + sql.append( + "ALTER TABLE " + + diff.get_name(self).get_quoted_name(self) + + " " + + query + ) + + if column_diff.has_changed("default") or 
column_diff.has_changed("type"): if column.get_default() is None: - default_clause = ' DROP DEFAULT' + default_clause = " DROP DEFAULT" else: - default_clause = ' SET' + self.get_default_value_declaration_sql(column.to_dict()) - - query = 'ALTER ' + old_column_name + default_clause - sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query) - - if column_diff.has_changed('notnull'): - op = 'DROP' + default_clause = " SET" + self.get_default_value_declaration_sql( + column.to_dict() + ) + + query = "ALTER " + old_column_name + default_clause + sql.append( + "ALTER TABLE " + + diff.get_name(self).get_quoted_name(self) + + " " + + query + ) + + if column_diff.has_changed("notnull"): + op = "DROP" if column.get_notnull(): - op = 'SET' + op = "SET" - query = 'ALTER ' + old_column_name + ' ' + op + ' NOT NULL' - sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query) + query = "ALTER " + old_column_name + " " + op + " NOT NULL" + sql.append( + "ALTER TABLE " + + diff.get_name(self).get_quoted_name(self) + + " " + + query + ) - if column_diff.has_changed('autoincrement'): + if column_diff.has_changed("autoincrement"): if column.get_autoincrement(): - seq_name = self.get_identity_sequence_name(diff.name, old_column_name) - - sql.append('CREATE SEQUENCE ' + seq_name) - sql.append('SELECT setval(\'' + seq_name + '\', ' - '(SELECT MAX(' + old_column_name + ') FROM ' + diff.name + '))') - query = 'ALTER ' + old_column_name + ' SET DEFAULT nextval(\'' + seq_name + '\')' - sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query) + seq_name = self.get_identity_sequence_name( + diff.name, old_column_name + ) + + sql.append("CREATE SEQUENCE " + seq_name) + sql.append( + "SELECT setval('" + seq_name + "', " + "(SELECT MAX(" + old_column_name + ") FROM " + diff.name + "))" + ) + query = ( + "ALTER " + + old_column_name + + " SET DEFAULT nextval('" + + seq_name + + "')" + ) + sql.append( + "ALTER TABLE 
" + + diff.get_name(self).get_quoted_name(self) + + " " + + query + ) else: - query = 'ALTER ' + old_column_name + ' DROP DEFAULT' - sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query) - - if column_diff.has_changed('length'): - query = 'ALTER ' + old_column_name + ' TYPE ' + self.get_sql_type_declaration(column.to_dict()) - sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query) + query = "ALTER " + old_column_name + " DROP DEFAULT" + sql.append( + "ALTER TABLE " + + diff.get_name(self).get_quoted_name(self) + + " " + + query + ) + + if column_diff.has_changed("length"): + query = ( + "ALTER " + + old_column_name + + " TYPE " + + self.get_sql_type_declaration(column.to_dict()) + ) + sql.append( + "ALTER TABLE " + + diff.get_name(self).get_quoted_name(self) + + " " + + query + ) for old_column_name, column in diff.renamed_columns.items(): - sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' - 'RENAME COLUMN ' + Identifier(old_column_name).get_quoted_name(self) + - ' TO ' + column.get_quoted_name(self)) + sql.append( + "ALTER TABLE " + diff.get_name(self).get_quoted_name(self) + " " + "RENAME COLUMN " + + Identifier(old_column_name).get_quoted_name(self) + + " TO " + + column.get_quoted_name(self) + ) return sql def is_unchanged_binary_column(self, column_diff): column_type = column_diff.column.get_type() - if column_type not in ['blob', 'binary']: + if column_type not in ["blob", "binary"]: return False if isinstance(column_diff.from_column, Column): @@ -233,15 +317,33 @@ def is_unchanged_binary_column(self, column_diff): if from_column: from_column_type = self.INTERNAL_TYPE_MAPPING[from_column.get_type()] - if from_column_type in ['blob', 'binary']: + if from_column_type in ["blob", "binary"]: return False - return len([x for x in column_diff.changed_properties if x not in ['type', 'length', 'fixed']]) == 0 - - if column_diff.has_changed('type'): + return ( + len( + [ + x + for 
x in column_diff.changed_properties + if x not in ["type", "length", "fixed"] + ] + ) + == 0 + ) + + if column_diff.has_changed("type"): return False - return len([x for x in column_diff.changed_properties if x not in ['length', 'fixed']]) == 0 + return ( + len( + [ + x + for x in column_diff.changed_properties + if x not in ["length", "fixed"] + ] + ) + == 0 + ) def convert_booleans(self, item): if isinstance(item, list): @@ -254,73 +356,73 @@ def convert_booleans(self, item): return item def get_boolean_type_declaration_sql(self, column): - return 'BOOLEAN' + return "BOOLEAN" def get_integer_type_declaration_sql(self, column): - if column.get('autoincrement'): - return 'SERIAL' + if column.get("autoincrement"): + return "SERIAL" - return 'INT' + return "INT" def get_bigint_type_declaration_sql(self, column): - if column.get('autoincrement'): - return 'BIGSERIAL' + if column.get("autoincrement"): + return "BIGSERIAL" - return 'BIGINT' + return "BIGINT" def get_smallint_type_declaration_sql(self, column): - return 'SMALLINT' + return "SMALLINT" def get_guid_type_declaration_sql(self, column): - return 'UUID' + return "UUID" def get_datetime_type_declaration_sql(self, column): - return 'TIMESTAMP(0) WITHOUT TIME ZONE' + return "TIMESTAMP(0) WITHOUT TIME ZONE" def get_datetimetz_type_declaration_sql(self, column): - return 'TIMESTAMP(0) WITH TIME ZONE' + return "TIMESTAMP(0) WITH TIME ZONE" def get_date_type_declaration_sql(self, column): - return 'DATE' + return "DATE" def get_time_type_declaration_sql(self, column): - return 'TIME(0) WITHOUT TIME ZONE' + return "TIME(0) WITHOUT TIME ZONE" def get_string_type_declaration_sql(self, column): - length = column.get('length', '255') - fixed = column.get('fixed') + length = column.get("length", "255") + fixed = column.get("fixed") if fixed: - return 'CHAR(%s)' % length + return "CHAR(%s)" % length else: - return 'VARCHAR(%s)' % length + return "VARCHAR(%s)" % length def get_binary_type_declaration_sql(self, column): - 
return 'BYTEA' + return "BYTEA" def get_blob_type_declaration_sql(self, column): - return 'BYTEA' + return "BYTEA" def get_clob_type_declaration_sql(self, column): - return 'TEXT' + return "TEXT" def get_text_type_declaration_sql(self, column): - return 'TEXT' + return "TEXT" def get_json_type_declaration_sql(self, column): - return 'JSON' + return "JSON" def get_decimal_type_declaration_sql(self, column): - if 'precision' not in column or not column['precision']: - column['precision'] = 10 + if "precision" not in column or not column["precision"]: + column["precision"] = 10 - if 'scale' not in column or not column['scale']: - column['precision'] = 0 + if "scale" not in column or not column["scale"]: + column["precision"] = 0 - return 'DECIMAL(%s, %s)' % (column['precision'], column['scale']) + return "DECIMAL(%s, %s)" % (column["precision"], column["scale"]) def get_float_type_declaration_sql(self, column): - return 'DOUBLE PRECISION' + return "DOUBLE PRECISION" def supports_foreign_key_constraints(self): return True diff --git a/orator/dbal/platforms/sqlite_platform.py b/orator/dbal/platforms/sqlite_platform.py index 05540d50..b5698859 100644 --- a/orator/dbal/platforms/sqlite_platform.py +++ b/orator/dbal/platforms/sqlite_platform.py @@ -14,54 +14,54 @@ class SQLitePlatform(Platform): INTERNAL_TYPE_MAPPING = { - 'boolean': 'boolean', - 'tinyint': 'boolean', - 'smallint': 'smallint', - 'mediumint': 'integer', - 'int': 'integer', - 'integer': 'integer', - 'serial': 'integer', - 'bigint': 'bigint', - 'bigserial': 'bigint', - 'clob': 'text', - 'tinytext': 'text', - 'mediumtext': 'text', - 'longtext': 'text', - 'text': 'text', - 'varchar': 'string', - 'longvarchar': 'string', - 'varchar2': 'string', - 'nvarchar': 'string', - 'image': 'string', - 'ntext': 'string', - 'char': 'string', - 'date': 'date', - 'datetime': 'datetime', - 'timestamp': 'datetime', - 'time': 'time', - 'float': 'float', - 'double': 'float', - 'double precision': 'float', - 'real': 'float', - 
'decimal': 'decimal', - 'numeric': 'decimal', - 'blob': 'blob', + "boolean": "boolean", + "tinyint": "boolean", + "smallint": "smallint", + "mediumint": "integer", + "int": "integer", + "integer": "integer", + "serial": "integer", + "bigint": "bigint", + "bigserial": "bigint", + "clob": "text", + "tinytext": "text", + "mediumtext": "text", + "longtext": "text", + "text": "text", + "varchar": "string", + "longvarchar": "string", + "varchar2": "string", + "nvarchar": "string", + "image": "string", + "ntext": "string", + "char": "string", + "date": "date", + "datetime": "datetime", + "timestamp": "datetime", + "time": "time", + "float": "float", + "double": "float", + "double precision": "float", + "real": "float", + "decimal": "decimal", + "numeric": "decimal", + "blob": "blob", } def get_list_table_columns_sql(self, table): - table = table.replace('.', '__') + table = table.replace(".", "__") - return 'PRAGMA table_info(\'%s\')' % table + return "PRAGMA table_info('%s')" % table def get_list_table_indexes_sql(self, table): - table = table.replace('.', '__') + table = table.replace(".", "__") - return 'PRAGMA index_list(\'%s\')' % table + return "PRAGMA index_list('%s')" % table def get_list_table_foreign_keys_sql(self, table): - table = table.replace('.', '__') + table = table.replace(".", "__") - return 'PRAGMA foreign_key_list(\'%s\')' % table + return "PRAGMA foreign_key_list('%s')" % table def get_pre_alter_table_index_foreign_key_sql(self, diff): """ @@ -71,8 +71,10 @@ def get_pre_alter_table_index_foreign_key_sql(self, diff): :rtype: list """ if not isinstance(diff.from_table, Table): - raise DBALException('Sqlite platform requires for alter table the table' - 'diff with reference to original table schema') + raise DBALException( + "Sqlite platform requires for alter table the table" + "diff with reference to original table schema" + ) sql = [] for index in diff.from_table.get_indexes().values(): @@ -89,8 +91,10 @@ def 
get_post_alter_table_index_foreign_key_sql(self, diff): :rtype: list """ if not isinstance(diff.from_table, Table): - raise DBALException('Sqlite platform requires for alter table the table' - 'diff with reference to original table schema') + raise DBALException( + "Sqlite platform requires for alter table the table" + "diff with reference to original table schema" + ) sql = [] @@ -103,7 +107,9 @@ def get_post_alter_table_index_foreign_key_sql(self, diff): if index.is_primary(): continue - sql.append(self.get_create_index_sql(index, table_name.get_quoted_name(self))) + sql.append( + self.get_create_index_sql(index, table_name.get_quoted_name(self)) + ) return sql @@ -114,34 +120,36 @@ def get_create_table_sql(self, table, create_flags=None): return super(SQLitePlatform, self).get_create_table_sql(table, create_flags) def _get_create_table_sql(self, table_name, columns, options=None): - table_name = table_name.replace('.', '__') + table_name = table_name.replace(".", "__") query_fields = self.get_column_declaration_list_sql(columns) - if options.get('unique_constraints'): - for name, definition in options['unique_constraints'].items(): - query_fields += ', %s' % self.get_unique_constraint_declaration_sql(name, definition) + if options.get("unique_constraints"): + for name, definition in options["unique_constraints"].items(): + query_fields += ", %s" % self.get_unique_constraint_declaration_sql( + name, definition + ) - if options.get('primary'): - key_columns = options['primary'] - query_fields += ', PRIMARY KEY(%s)' % ', '.join(key_columns) + if options.get("primary"): + key_columns = options["primary"] + query_fields += ", PRIMARY KEY(%s)" % ", ".join(key_columns) - if options.get('foreign_keys'): - for foreign_key in options['foreign_keys']: - query_fields += ', %s' % self.get_foreign_key_declaration_sql(foreign_key) + if options.get("foreign_keys"): + for foreign_key in options["foreign_keys"]: + query_fields += ", %s" % self.get_foreign_key_declaration_sql( + 
foreign_key + ) - query = [ - 'CREATE TABLE %s (%s)' % (table_name, query_fields) - ] + query = ["CREATE TABLE %s (%s)" % (table_name, query_fields)] - if options.get('alter'): + if options.get("alter"): return query - if options.get('indexes'): - for index_def in options['indexes'].values(): + if options.get("indexes"): + for index_def in options["indexes"].values(): query.append(self.get_create_index_sql(index_def, table_name)) - if options.get('unique'): - for index_def in options['unique'].values(): + if options.get("unique"): + for index_def in options["unique"].values(): query.append(self.get_create_index_sql(index_def, table_name)) return query @@ -150,29 +158,37 @@ def get_foreign_key_declaration_sql(self, foreign_key): return super(SQLitePlatform, self).get_foreign_key_declaration_sql( ForeignKeyConstraint( foreign_key.get_quoted_local_columns(self), - foreign_key.get_quoted_foreign_table_name(self).replace('.', '__'), + foreign_key.get_quoted_foreign_table_name(self).replace(".", "__"), foreign_key.get_quoted_foreign_columns(self), foreign_key.get_name(), - foreign_key.get_options() + foreign_key.get_options(), ) ) def get_advanced_foreign_key_options_sql(self, foreign_key): - query = super(SQLitePlatform, self).get_advanced_foreign_key_options_sql(foreign_key) + query = super(SQLitePlatform, self).get_advanced_foreign_key_options_sql( + foreign_key + ) - deferrable = foreign_key.has_option('deferrable') and foreign_key.get_option('deferrable') is not False + deferrable = ( + foreign_key.has_option("deferrable") + and foreign_key.get_option("deferrable") is not False + ) if deferrable: - query += ' DEFERRABLE' + query += " DEFERRABLE" else: - query += ' NOT DEFERRABLE' + query += " NOT DEFERRABLE" - query += ' INITIALLY' + query += " INITIALLY" - deferred = foreign_key.has_option('deferred') and foreign_key.get_option('deferred') is not False + deferred = ( + foreign_key.has_option("deferred") + and foreign_key.get_option("deferred") is not False + ) if 
deferred: - query += ' DEFERRED' + query += " DEFERRED" else: - query += ' IMMEDIATE' + query += " IMMEDIATE" return query @@ -192,8 +208,8 @@ def get_alter_table_sql(self, diff): from_table = diff.from_table if not isinstance(from_table, Table): raise DBALException( - 'SQLite platform requires for the alter table the table diff ' - 'referencing the original table' + "SQLite platform requires for the alter table the table diff " + "referencing the original table" ) table = from_table.clone() @@ -231,33 +247,46 @@ def get_alter_table_sql(self, diff): columns[column_diff.column.get_name().lower()] = column_diff.column if old_column_name in new_column_names: - new_column_names[old_column_name] = column_diff.column.get_quoted_name(self) + new_column_names[old_column_name] = column_diff.column.get_quoted_name( + self + ) for column_name, column in diff.added_columns.items(): columns[column_name.lower()] = column table_sql = [] - data_table = Table('__temp__' + table.get_name()) - new_table = Table(table.get_quoted_name(self), columns, - self._get_primary_index_in_altered_table(diff), - self._get_foreign_keys_in_altered_table(diff), - table.get_options()) - new_table.add_option('alter', True) + data_table = Table("__temp__" + table.get_name()) + new_table = Table( + table.get_quoted_name(self), + columns, + self._get_primary_index_in_altered_table(diff), + self._get_foreign_keys_in_altered_table(diff), + table.get_options(), + ) + new_table.add_option("alter", True) sql = self.get_pre_alter_table_index_foreign_key_sql(diff) - sql.append('CREATE TEMPORARY TABLE %s AS SELECT %s FROM %s' - % (data_table.get_quoted_name(self), - ', '.join(old_column_names.values()), - table.get_quoted_name(self))) + sql.append( + "CREATE TEMPORARY TABLE %s AS SELECT %s FROM %s" + % ( + data_table.get_quoted_name(self), + ", ".join(old_column_names.values()), + table.get_quoted_name(self), + ) + ) sql.append(self.get_drop_table_sql(from_table)) sql += self.get_create_table_sql(new_table) - 
sql.append('INSERT INTO %s (%s) SELECT %s FROM %s' - % (new_table.get_quoted_name(self), - ', '.join(new_column_names.values()), - ', '.join(old_column_names.values()), - data_table.get_name())) + sql.append( + "INSERT INTO %s (%s) SELECT %s FROM %s" + % ( + new_table.get_quoted_name(self), + ", ".join(new_column_names.values()), + ", ".join(old_column_names.values()), + data_table.get_name(), + ) + ) sql.append(self.get_drop_table_sql(data_table)) sql += self.get_post_alter_table_index_foreign_key_sql(diff) @@ -266,26 +295,40 @@ def get_alter_table_sql(self, diff): def _get_simple_alter_table_sql(self, diff): for old_column_name, column_diff in diff.changed_columns.items(): - if not isinstance(column_diff.from_column, Column)\ - or not isinstance(column_diff.column, Column)\ - or not column_diff.column.get_autoincrement()\ - or column_diff.column.get_type().lower() != 'integer': + if ( + not isinstance(column_diff.from_column, Column) + or not isinstance(column_diff.column, Column) + or not column_diff.column.get_autoincrement() + or column_diff.column.get_type().lower() != "integer" + ): continue - if not column_diff.has_changed('type') and not column_diff.has_changed('unsigned'): + if not column_diff.has_changed("type") and not column_diff.has_changed( + "unsigned" + ): del diff.changed_columns[old_column_name] continue from_column_type = column_diff.column.get_type() - if from_column_type == 'smallint' or from_column_type == 'bigint': + if from_column_type == "smallint" or from_column_type == "bigint": del diff.changed_columns[old_column_name] - if any([not diff.renamed_columns, not diff.added_foreign_keys, not diff.added_indexes, - not diff.changed_columns, not diff.changed_foreign_keys, not diff.changed_indexes, - not diff.removed_columns, not diff.removed_foreign_keys, not diff.removed_indexes, - not diff.renamed_indexes]): + if any( + [ + not diff.renamed_columns, + not diff.added_foreign_keys, + not diff.added_indexes, + not diff.changed_columns, + not 
diff.changed_foreign_keys, + not diff.changed_indexes, + not diff.removed_columns, + not diff.removed_foreign_keys, + not diff.removed_indexes, + not diff.renamed_indexes, + ] + ): return False table = Table(diff.name) @@ -295,34 +338,45 @@ def _get_simple_alter_table_sql(self, diff): column_sql = [] for column in diff.added_columns.values(): - field = { - 'unique': None, - 'autoincrement': None, - 'default': None - } + field = {"unique": None, "autoincrement": None, "default": None} field.update(column.to_dict()) - type_ = field['type'] - if 'column_definition' in field or field['autoincrement'] or field['unique']: + type_ = field["type"] + if ( + "column_definition" in field + or field["autoincrement"] + or field["unique"] + ): return False - elif type_ == 'datetime' and field['default'] == self.get_current_timestamp_sql(): + elif ( + type_ == "datetime" + and field["default"] == self.get_current_timestamp_sql() + ): return False - elif type_ == 'date' and field['default'] == self.get_current_date_sql(): + elif type_ == "date" and field["default"] == self.get_current_date_sql(): return False - elif type_ == 'time' and field['default'] == self.get_current_time_sql(): + elif type_ == "time" and field["default"] == self.get_current_time_sql(): return False - field['name'] = column.get_quoted_name(self) - if field['type'].lower() == 'string' and field['length'] is None: - field['length'] = 255 + field["name"] = column.get_quoted_name(self) + if field["type"].lower() == "string" and field["length"] is None: + field["length"] = 255 - sql.append('ALTER TABLE ' + table.get_quoted_name(self) + - ' ADD COLUMN ' + self.get_column_declaration_sql(field['name'], field)) + sql.append( + "ALTER TABLE " + + table.get_quoted_name(self) + + " ADD COLUMN " + + self.get_column_declaration_sql(field["name"], field) + ) if diff.new_name is not False: new_table = Identifier(diff.new_name) - sql.append('ALTER TABLE ' + table.get_quoted_name(self) + - ' RENAME TO ' + 
new_table.get_quoted_name(self)) + sql.append( + "ALTER TABLE " + + table.get_quoted_name(self) + + " RENAME TO " + + new_table.get_quoted_name(self) + ) return sql @@ -354,9 +408,13 @@ def _get_indexes_in_altered_table(self, diff): changed = True if changed: - indexes[key] = Index(index.get_name(), index_columns, - index.is_unique(), index.is_primary(), - index.get_flags()) + indexes[key] = Index( + index.get_name(), + index_columns, + index.is_unique(), + index.is_primary(), + index.get_flags(), + ) for index in diff.removed_indexes.values(): index_name = index.get_name().lower() @@ -438,7 +496,7 @@ def _get_foreign_keys_in_altered_table(self, diff): constraint.get_foreign_table_name(), constraint.get_foreign_columns(), constraint.get_name(), - constraint.get_options() + constraint.get_options(), ) for constraint in diff.removed_foreign_keys: @@ -475,72 +533,72 @@ def supports_foreign_key_constraints(self): return True def get_boolean_type_declaration_sql(self, column): - return 'BOOLEAN' + return "BOOLEAN" def get_integer_type_declaration_sql(self, column): - return 'INTEGER' + self._get_common_integer_type_declaration_sql(column) + return "INTEGER" + self._get_common_integer_type_declaration_sql(column) def get_bigint_type_declaration_sql(self, column): # SQLite autoincrement is implicit for INTEGER PKs, but not for BIGINT fields. - if not column.get('autoincrement', False): + if not column.get("autoincrement", False): return self.get_integer_type_declaration_sql(column) - return 'BIGINT' + self._get_common_integer_type_declaration_sql(column) + return "BIGINT" + self._get_common_integer_type_declaration_sql(column) def get_tinyint_type_declaration_sql(self, column): # SQLite autoincrement is implicit for INTEGER PKs, but not for TINYINT fields. 
- if not column.get('autoincrement', False): + if not column.get("autoincrement", False): return self.get_integer_type_declaration_sql(column) - return 'TINYINT' + self._get_common_integer_type_declaration_sql(column) + return "TINYINT" + self._get_common_integer_type_declaration_sql(column) def get_smallint_type_declaration_sql(self, column): # SQLite autoincrement is implicit for INTEGER PKs, but not for SMALLINT fields. - if not column.get('autoincrement', False): + if not column.get("autoincrement", False): return self.get_integer_type_declaration_sql(column) - return 'SMALLINT' + self._get_common_integer_type_declaration_sql(column) + return "SMALLINT" + self._get_common_integer_type_declaration_sql(column) def get_mediumint_type_declaration_sql(self, column): # SQLite autoincrement is implicit for INTEGER PKs, but not for MEDIUMINT fields. - if not column.get('autoincrement', False): + if not column.get("autoincrement", False): return self.get_integer_type_declaration_sql(column) - return 'MEDIUMINT' + self._get_common_integer_type_declaration_sql(column) + return "MEDIUMINT" + self._get_common_integer_type_declaration_sql(column) def get_datetime_type_declaration_sql(self, column): - return 'DATETIME' + return "DATETIME" def get_date_type_declaration_sql(self, column): - return 'DATE' + return "DATE" def get_time_type_declaration_sql(self, column): - return 'TIME' + return "TIME" def _get_common_integer_type_declaration_sql(self, column): # sqlite autoincrement is implicit for integer PKs, but not when the field is unsigned - if not column.get('autoincrement', False): - return '' + if not column.get("autoincrement", False): + return "" - if not column.get('unsigned', False): - return ' UNSIGNED' + if not column.get("unsigned", False): + return " UNSIGNED" - return '' + return "" def get_varchar_type_declaration_sql_snippet(self, length, fixed): if fixed: - return 'CHAR(%s)' % length if length else 'CHAR(255)' + return "CHAR(%s)" % length if length else 
"CHAR(255)" else: - return 'VARCHAR(%s)' % length if length else 'TEXT' + return "VARCHAR(%s)" % length if length else "TEXT" def get_blob_type_declaration_sql(self, column): - return 'BLOB' + return "BLOB" def get_clob_type_declaration_sql(self, column): - return 'CLOB' + return "CLOB" def get_column_options(self): - return ['pk'] + return ["pk"] def _get_reserved_keywords_class(self): return SQLiteKeywords diff --git a/orator/dbal/postgres_schema_manager.py b/orator/dbal/postgres_schema_manager.py index 1390bcc0..e2724a3d 100644 --- a/orator/dbal/postgres_schema_manager.py +++ b/orator/dbal/postgres_schema_manager.py @@ -7,98 +7,110 @@ class PostgresSchemaManager(SchemaManager): - def _get_portable_table_column_definition(self, table_column): - if table_column['type'].lower() == 'varchar' or table_column['type'] == 'bpchar': - length = re.sub('.*\(([0-9]*)\).*', '\\1', table_column['complete_type']) - table_column['length'] = length + if ( + table_column["type"].lower() == "varchar" + or table_column["type"] == "bpchar" + ): + length = re.sub(".*\(([0-9]*)\).*", "\\1", table_column["complete_type"]) + table_column["length"] = length autoincrement = False - match = re.match("^nextval\('?(.*)'?(::.*)?\)$", str(table_column['default'])) + match = re.match("^nextval\('?(.*)'?(::.*)?\)$", str(table_column["default"])) if match: - table_column['sequence'] = match.group(1) - table_column['default'] = None + table_column["sequence"] = match.group(1) + table_column["default"] = None autoincrement = True - match = re.match("^'?([^']*)'?::.*$", str(table_column['default'])) + match = re.match("^'?([^']*)'?::.*$", str(table_column["default"])) if match: - table_column['default'] = match.group(1) + table_column["default"] = match.group(1) - if str(table_column['default']).find('NULL') == 0: - table_column['default'] = None + if str(table_column["default"]).find("NULL") == 0: + table_column["default"] = None - if 'length' in table_column: - length = table_column['length'] + if 
"length" in table_column: + length = table_column["length"] else: length = None - if length == '-1' and 'atttypmod' in table_column: - length = table_column['atttypmod'] - 4 + if length == "-1" and "atttypmod" in table_column: + length = table_column["atttypmod"] - 4 if length is None or not length.isdigit() or int(length) <= 0: length = None fixed = None - if 'name' not in table_column: - table_column['name'] = '' + if "name" not in table_column: + table_column["name"] = "" precision = None scale = None - db_type = table_column['type'].lower() + db_type = table_column["type"].lower() type = self._platform.get_type_mapping(db_type) - if db_type in ['smallint', 'int2']: + if db_type in ["smallint", "int2"]: length = None - elif db_type in ['int', 'int4', 'integer']: + elif db_type in ["int", "int4", "integer"]: length = None - elif db_type in ['int8', 'bigint']: + elif db_type in ["int8", "bigint"]: length = None - elif db_type in ['bool', 'boolean']: - if table_column['default'] == 'true': - table_column['default'] = True + elif db_type in ["bool", "boolean"]: + if table_column["default"] == "true": + table_column["default"] = True - if table_column['default'] == 'false': - table_column['default'] = False + if table_column["default"] == "false": + table_column["default"] = False length = None - elif db_type == 'text': + elif db_type == "text": fixed = False - elif db_type in ['varchar', 'interval', '_varchar']: + elif db_type in ["varchar", "interval", "_varchar"]: fixed = False - elif db_type in ['char', 'bpchar']: + elif db_type in ["char", "bpchar"]: fixed = True - elif db_type in ['float', 'float4', 'float8', - 'double', 'double precision', - 'real', 'decimal', 'money', 'numeric']: - match = re.match('([A-Za-z]+\(([0-9]+),([0-9]+)\))', table_column['complete_type']) + elif db_type in [ + "float", + "float4", + "float8", + "double", + "double precision", + "real", + "decimal", + "money", + "numeric", + ]: + match = re.match( + "([A-Za-z]+\(([0-9]+),([0-9]+)\))", 
table_column["complete_type"] + ) if match: precision = match.group(1) scale = match.group(2) length = None - elif db_type == 'year': + elif db_type == "year": length = None - if table_column['default']: - match = re.match("('?([^']+)'?::)", str(table_column['default'])) + if table_column["default"]: + match = re.match("('?([^']+)'?::)", str(table_column["default"])) if match: - table_column['default'] = match.group(1) + table_column["default"] = match.group(1) options = { - 'length': length, - 'notnull': table_column['isnotnull'], - 'default': table_column['default'], - 'primary': table_column['pri'] == 't', - 'precision': precision, - 'scale': scale, - 'fixed': fixed, - 'unsigned': False, - 'autoincrement': autoincrement + "length": length, + "notnull": table_column["isnotnull"], + "default": table_column["default"], + "primary": table_column["pri"] == "t", + "precision": precision, + "scale": scale, + "fixed": fixed, + "unsigned": False, + "autoincrement": autoincrement, } - column = Column(table_column['field'], type, options) + column = Column(table_column["field"], type, options) return column @@ -106,48 +118,64 @@ def _get_portable_table_indexes_list(self, table_indexes, table_name): buffer = [] for row in table_indexes: - col_numbers = row['indkey'].split(' ') - col_numbers_sql = 'IN (%s)' % ', '.join(col_numbers) - column_name_sql = 'SELECT attnum, attname FROM pg_attribute ' \ - 'WHERE attrelid=%s AND attnum %s ORDER BY attnum ASC;'\ - % (row['indrelid'], col_numbers_sql) + col_numbers = row["indkey"].split(" ") + col_numbers_sql = "IN (%s)" % ", ".join(col_numbers) + column_name_sql = ( + "SELECT attnum, attname FROM pg_attribute " + "WHERE attrelid=%s AND attnum %s ORDER BY attnum ASC;" + % (row["indrelid"], col_numbers_sql) + ) index_columns = self._connection.select(column_name_sql) # required for getting the order of the columns right. 
for col_num in col_numbers: for col_row in index_columns: - if int(col_num) == col_row['attnum']: - buffer.append({ - 'key_name': row['relname'], - 'column_name': col_row['attname'].strip(), - 'non_unique': not row['indisunique'], - 'primary': row['indisprimary'], - 'where': row['where'] - }) - - return super(PostgresSchemaManager, self)._get_portable_table_indexes_list(buffer, table_name) + if int(col_num) == col_row["attnum"]: + buffer.append( + { + "key_name": row["relname"], + "column_name": col_row["attname"].strip(), + "non_unique": not row["indisunique"], + "primary": row["indisprimary"], + "where": row["where"], + } + ) + + return super(PostgresSchemaManager, self)._get_portable_table_indexes_list( + buffer, table_name + ) def _get_portable_table_foreign_key_definition(self, table_foreign_key): - on_update = '' - on_delete = '' + on_update = "" + on_delete = "" - match = re.match('ON UPDATE ([a-zA-Z0-9]+( (NULL|ACTION|DEFAULT))?)', table_foreign_key['condef']) + match = re.match( + "ON UPDATE ([a-zA-Z0-9]+( (NULL|ACTION|DEFAULT))?)", + table_foreign_key["condef"], + ) if match: on_update = match.group(1) - match = re.match('ON DELETE ([a-zA-Z0-9]+( (NULL|ACTION|DEFAULT))?)', table_foreign_key['condef']) + match = re.match( + "ON DELETE ([a-zA-Z0-9]+( (NULL|ACTION|DEFAULT))?)", + table_foreign_key["condef"], + ) if match: on_delete = match.group(1) - values = re.match('FOREIGN KEY \((.+)\) REFERENCES (.+)\((.+)\)', table_foreign_key['condef']) + values = re.match( + "FOREIGN KEY \((.+)\) REFERENCES (.+)\((.+)\)", table_foreign_key["condef"] + ) if values: - local_columns = [c.strip() for c in values.group(1).split(',')] - foreign_columns = [c.strip() for c in values.group(3).split(',')] + local_columns = [c.strip() for c in values.group(1).split(",")] + foreign_columns = [c.strip() for c in values.group(3).split(",")] foreign_table = values.group(2) return ForeignKeyConstraint( - local_columns, foreign_table, foreign_columns, - table_foreign_key['conname'], 
- {'on_update': on_update, 'on_delete': on_delete} + local_columns, + foreign_table, + foreign_columns, + table_foreign_key["conname"], + {"on_update": on_update, "on_delete": on_delete}, ) diff --git a/orator/dbal/schema_manager.py b/orator/dbal/schema_manager.py index 83a464d3..df9d913b 100644 --- a/orator/dbal/schema_manager.py +++ b/orator/dbal/schema_manager.py @@ -7,7 +7,6 @@ class SchemaManager(object): - def __init__(self, connection, platform=None): """ :param connection: The connection to use @@ -78,35 +77,38 @@ def _get_portable_table_indexes_list(self, table_indexes, table_name): result = OrderedDict() for table_index in table_indexes: - index_name = table_index['key_name'] - key_name = table_index['key_name'] - if table_index['primary']: - key_name = 'primary' + index_name = table_index["key_name"] + key_name = table_index["key_name"] + if table_index["primary"]: + key_name = "primary" key_name = key_name.lower() if key_name not in result: options = {} - if 'where' in table_index: - options['where'] = table_index['where'] + if "where" in table_index: + options["where"] = table_index["where"] result[key_name] = { - 'name': index_name, - 'columns': [table_index['column_name']], - 'unique': not table_index['non_unique'], - 'primary': table_index['primary'], - 'flags': table_index.get('flags') or None, - 'options': options + "name": index_name, + "columns": [table_index["column_name"]], + "unique": not table_index["non_unique"], + "primary": table_index["primary"], + "flags": table_index.get("flags") or None, + "options": options, } else: - result[key_name]['columns'].append(table_index['column_name']) + result[key_name]["columns"].append(table_index["column_name"]) indexes = OrderedDict() for index_key, data in result.items(): index = Index( - data['name'], data['columns'], - data['unique'], data['primary'], - data['flags'], data['options'] + data["name"], + data["columns"], + data["unique"], + data["primary"], + data["flags"], + data["options"], ) 
indexes[index_key] = index diff --git a/orator/dbal/sqlite_schema_manager.py b/orator/dbal/sqlite_schema_manager.py index 0db7eb1f..844dc105 100644 --- a/orator/dbal/sqlite_schema_manager.py +++ b/orator/dbal/sqlite_schema_manager.py @@ -8,67 +8,68 @@ class SQLiteSchemaManager(SchemaManager): - def _get_portable_table_column_definition(self, table_column): - parts = table_column['type'].split('(') - table_column['type'] = parts[0] + parts = table_column["type"].split("(") + table_column["type"] = parts[0] if len(parts) > 1: - length = parts[1].strip(')') - table_column['length'] = length + length = parts[1].strip(")") + table_column["length"] = length - db_type = table_column['type'].lower() - length = table_column.get('length', None) + db_type = table_column["type"].lower() + length = table_column.get("length", None) unsigned = False - if ' unsigned' in db_type: - db_type = db_type.replace(' unsigned', '') + if " unsigned" in db_type: + db_type = db_type.replace(" unsigned", "") unsigned = True fixed = False type = self._platform.get_type_mapping(db_type) - default = table_column['dflt_value'] - if default == 'NULL': + default = table_column["dflt_value"] + if default == "NULL": default = None if default is not None: # SQLite returns strings wrapped in single quotes, so we need to strip them - default = re.sub("^'(.*)'$", '\\1', default) + default = re.sub("^'(.*)'$", "\\1", default) - notnull = bool(table_column['notnull']) + notnull = bool(table_column["notnull"]) - if 'name' not in table_column: - table_column['name'] = '' + if "name" not in table_column: + table_column["name"] = "" precision = None scale = None - if db_type in ['char']: + if db_type in ["char"]: fixed = True - elif db_type in ['varchar']: + elif db_type in ["varchar"]: length = length or 255 - elif db_type in ['float', 'double', 'real', 'decimal', 'numeric']: - if 'length' in table_column: - if ',' not in table_column['length']: - table_column['length'] += ',0' + elif db_type in ["float", 
"double", "real", "decimal", "numeric"]: + if "length" in table_column: + if "," not in table_column["length"]: + table_column["length"] += ",0" - precision, scale = tuple(map(lambda x: x.strip(), table_column['length'].split(','))) + precision, scale = tuple( + map(lambda x: x.strip(), table_column["length"].split(",")) + ) length = None options = { - 'length': length, - 'unsigned': bool(unsigned), - 'fixed': fixed, - 'notnull': notnull, - 'default': default, - 'precision': precision, - 'scale': scale, - 'autoincrement': False + "length": length, + "unsigned": bool(unsigned), + "fixed": fixed, + "notnull": notnull, + "default": default, + "precision": precision, + "scale": scale, + "autoincrement": False, } - column = Column(table_column['name'], type, options) - column.set_platform_option('pk', table_column['pk']) + column = Column(table_column["name"], type, options) + column.set_platform_option("pk", table_column["pk"]) return column @@ -76,78 +77,84 @@ def _get_portable_table_indexes_list(self, table_indexes, table_name): index_buffer = [] # Fetch primary - info = self._connection.select('PRAGMA TABLE_INFO (%s)' % table_name) + info = self._connection.select("PRAGMA TABLE_INFO (%s)" % table_name) for row in info: - if row['pk'] != 0: - index_buffer.append({ - 'key_name': 'primary', - 'primary': True, - 'non_unique': False, - 'column_name': row['name'] - }) + if row["pk"] != 0: + index_buffer.append( + { + "key_name": "primary", + "primary": True, + "non_unique": False, + "column_name": row["name"], + } + ) # Fetch regular indexes for index in table_indexes: # Ignore indexes with reserved names, e.g. 
autoindexes - if index['name'].find('sqlite_') == -1: - key_name = index['name'] + if index["name"].find("sqlite_") == -1: + key_name = index["name"] idx = { - 'key_name': key_name, - 'primary': False, - 'non_unique': not bool(index['unique']) + "key_name": key_name, + "primary": False, + "non_unique": not bool(index["unique"]), } - info = self._connection.select('PRAGMA INDEX_INFO (\'%s\')' % key_name) + info = self._connection.select("PRAGMA INDEX_INFO ('%s')" % key_name) for row in info: - idx['column_name'] = row['name'] + idx["column_name"] = row["name"] index_buffer.append(idx) - return super(SQLiteSchemaManager, self)._get_portable_table_indexes_list(index_buffer, table_name) + return super(SQLiteSchemaManager, self)._get_portable_table_indexes_list( + index_buffer, table_name + ) def _get_portable_table_foreign_keys_list(self, table_foreign_keys): foreign_keys = OrderedDict() for value in table_foreign_keys: value = dict((k.lower(), v) for k, v in value.items()) - name = value.get('constraint_name', None) + name = value.get("constraint_name", None) if name is None: - name = '%s_%s_%s' % (value['from'], value['table'], value['to']) + name = "%s_%s_%s" % (value["from"], value["table"], value["to"]) if name not in foreign_keys: - if 'on_delete' not in value or value['on_delete'] == 'RESTRICT': - value['on_delete'] = None + if "on_delete" not in value or value["on_delete"] == "RESTRICT": + value["on_delete"] = None - if 'on_update' not in value or value['on_update'] == 'RESTRICT': - value['on_update'] = None + if "on_update" not in value or value["on_update"] == "RESTRICT": + value["on_update"] = None foreign_keys[name] = { - 'name': name, - 'local': [], - 'foreign': [], - 'foreign_table': value['table'], - 'on_delete': value['on_delete'], - 'on_update': value['on_update'], - 'deferrable': value.get('deferrable', False), - 'deferred': value.get('deferred', False) + "name": name, + "local": [], + "foreign": [], + "foreign_table": value["table"], + "on_delete": 
value["on_delete"], + "on_update": value["on_update"], + "deferrable": value.get("deferrable", False), + "deferred": value.get("deferred", False), } - foreign_keys[name]['local'].append(value['from']) - foreign_keys[name]['foreign'].append(value['to']) + foreign_keys[name]["local"].append(value["from"]) + foreign_keys[name]["foreign"].append(value["to"]) result = [] for constraint in foreign_keys.values(): result.append( ForeignKeyConstraint( - constraint['local'], constraint['foreign_table'], - constraint['foreign'], constraint['name'], + constraint["local"], + constraint["foreign_table"], + constraint["foreign"], + constraint["name"], { - 'on_delete': constraint['on_delete'], - 'on_update': constraint['on_update'], - 'deferrable': constraint['deferrable'], - 'deferred': constraint['deferred'] - } + "on_delete": constraint["on_delete"], + "on_update": constraint["on_update"], + "deferrable": constraint["deferrable"], + "deferred": constraint["deferred"], + }, ) ) diff --git a/orator/dbal/table.py b/orator/dbal/table.py index 887c56a0..9eaefe1f 100644 --- a/orator/dbal/table.py +++ b/orator/dbal/table.py @@ -8,15 +8,19 @@ from .foreign_key_constraint import ForeignKeyConstraint from .exceptions import ( DBALException, - IndexDoesNotExist, IndexAlreadyExists, IndexNameInvalid, - ColumnDoesNotExist, ColumnAlreadyExists, - ForeignKeyDoesNotExist + IndexDoesNotExist, + IndexAlreadyExists, + IndexNameInvalid, + ColumnDoesNotExist, + ColumnAlreadyExists, + ForeignKeyDoesNotExist, ) class Table(AbstractAsset): - - def __init__(self, table_name, columns=None, indexes=None, fk_constraints=None, options=None): + def __init__( + self, table_name, columns=None, indexes=None, fk_constraints=None, options=None + ): self._set_name(table_name) self._primary_key_name = False self._columns = OrderedDict() @@ -37,7 +41,11 @@ def __init__(self, table_name, columns=None, indexes=None, fk_constraints=None, for index in indexes: self._add_index(index) - fk_constraints = 
fk_constraints.values() if isinstance(fk_constraints, dict) else fk_constraints + fk_constraints = ( + fk_constraints.values() + if isinstance(fk_constraints, dict) + else fk_constraints + ) for constraint in fk_constraints: self._add_foreign_key_constraint(constraint) @@ -53,7 +61,9 @@ def set_primary_key(self, columns, index_name=False): :rtype: Table """ - self._add_index(self._create_index(columns, index_name or 'primary', True, True)) + self._add_index( + self._create_index(columns, index_name or "primary", True, True) + ) for column_name in columns: column = self.get_column(column_name) @@ -64,10 +74,12 @@ def set_primary_key(self, columns, index_name=False): def add_index(self, columns, name=None, flags=None, options=None): if not name: name = self._generate_identifier_name( - [self.get_name()] + columns, 'idx', self._get_max_identifier_length() + [self.get_name()] + columns, "idx", self._get_max_identifier_length() ) - return self._add_index(self._create_index(columns, name, False, False, flags, options)) + return self._add_index( + self._create_index(columns, name, False, False, flags, options) + ) def drop_primary_key(self): """ @@ -92,10 +104,12 @@ def drop_index(self, name): def add_unique_index(self, columns, name=None, options=None): if not name: name = self._generate_identifier_name( - [self.get_name()] + columns, 'uniq', self._get_max_identifier_length() + [self.get_name()] + columns, "uniq", self._get_max_identifier_length() ) - return self._add_index(self._create_index(columns, name, True, False, None, options)) + return self._add_index( + self._create_index(columns, name, True, False, None, options) + ) def rename_index(self, old_name, new_name=None): """ @@ -149,7 +163,9 @@ def columns_are_indexed(self, columns): return False - def _create_index(self, columns, name, is_unique, is_primary, flags=None, options=None): + def _create_index( + self, columns, name, is_unique, is_primary, flags=None, options=None + ): """ Creates an Index instance. 
@@ -173,7 +189,7 @@ def _create_index(self, columns, name, is_unique, is_primary, flags=None, option :rtype: Index """ - if re.match('[^a-zA-Z0-9_]+', self._normalize_identifier(name)): + if re.match("[^a-zA-Z0-9_]+", self._normalize_identifier(name)): raise IndexNameInvalid(name) for column in columns: @@ -237,8 +253,14 @@ def drop_column(self, name): return self - def add_foreign_key_constraint(self, foreign_table, local_columns, - foreign_columns, options=None, constraint_name=None): + def add_foreign_key_constraint( + self, + foreign_table, + local_columns, + foreign_columns, + options=None, + constraint_name=None, + ): """ Adds a foreign key constraint. @@ -259,16 +281,18 @@ def add_foreign_key_constraint(self, foreign_table, local_columns, """ if not constraint_name: constraint_name = self._generate_identifier_name( - [self.get_name()] + local_columns, 'fk', self._get_max_identifier_length() + [self.get_name()] + local_columns, + "fk", + self._get_max_identifier_length(), ) return self.add_named_foreign_key_constraint( - constraint_name, foreign_table, - local_columns, foreign_columns, options + constraint_name, foreign_table, local_columns, foreign_columns, options ) - def add_named_foreign_key_constraint(self, name, foreign_table, - local_columns, foreign_columns, options): + def add_named_foreign_key_constraint( + self, name, foreign_table, local_columns, foreign_columns, options + ): """ Adds a foreign key constraint with a given name. 
@@ -334,8 +358,10 @@ def _add_index(self, index): replaced_implicit_indexes.append(name) already_exists = ( - index_name in self._indexes and index_name not in replaced_implicit_indexes - or self._primary_key_name is not False and index.is_primary() + index_name in self._indexes + and index_name not in replaced_implicit_indexes + or self._primary_key_name is not False + and index.is_primary() ) if already_exists: raise IndexAlreadyExists(index_name, self._name) @@ -366,7 +392,9 @@ def _add_foreign_key_constraint(self, constraint): name = constraint.get_name() else: name = self._generate_identifier_name( - [self.get_name()] + constraint.get_local_columns(), 'fk', self._get_max_identifier_length() + [self.get_name()] + constraint.get_local_columns(), + "fk", + self._get_max_identifier_length(), ) name = self._normalize_identifier(name) @@ -380,16 +408,20 @@ def _add_foreign_key_constraint(self, constraint): # This creates computation overhead in this case, however no duplicate indexes # are ever added (based on columns). 
index_name = self._generate_identifier_name( - [self.get_name()] + constraint.get_columns(), 'idx', self._get_max_identifier_length() + [self.get_name()] + constraint.get_columns(), + "idx", + self._get_max_identifier_length(), + ) + index_candidate = self._create_index( + constraint.get_columns(), index_name, False, False ) - index_candidate = self._create_index(constraint.get_columns(), index_name, False, False) for existing_index in self._indexes.values(): if index_candidate.is_fullfilled_by(existing_index): return - #self._add_index(index_candidate) - #self._implicit_indexes[self._normalize_identifier(index_name)] = index_candidate + # self._add_index(index_candidate) + # self._implicit_indexes[self._normalize_identifier(index_name)] = index_candidate return self @@ -550,19 +582,27 @@ def clone(self): table._primary_key_name = self._primary_key_name for k, column in self._columns.items(): - table._columns[k] = Column(column.get_name(), column.get_type(), column.to_dict()) + table._columns[k] = Column( + column.get_name(), column.get_type(), column.to_dict() + ) for k, index in self._indexes.items(): table._indexes[k] = Index( - index.get_name(), index.get_columns(), - index.is_unique(), index.is_primary(), - index.get_flags(), index.get_options() + index.get_name(), + index.get_columns(), + index.is_unique(), + index.is_primary(), + index.get_flags(), + index.get_options(), ) for k, fk in self._fk_constraints.items(): table._fk_constraints[k] = ForeignKeyConstraint( - fk.get_local_columns(), fk.get_foreign_table_name(), - fk.get_foreign_columns(), fk.get_name(), fk.get_options() + fk.get_local_columns(), + fk.get_foreign_table_name(), + fk.get_foreign_columns(), + fk.get_name(), + fk.get_options(), ) table._fk_constraints[k].set_local_table(table) diff --git a/orator/dbal/table_diff.py b/orator/dbal/table_diff.py index c06a7647..dc012fd5 100644 --- a/orator/dbal/table_diff.py +++ b/orator/dbal/table_diff.py @@ -6,10 +6,17 @@ class TableDiff(object): - - def 
__init__(self, table_name, added_columns=None, - changed_columns=None, removed_columns=None, added_indexes=None, - changed_indexes=None, removed_indexes=None, from_table=None): + def __init__( + self, + table_name, + added_columns=None, + changed_columns=None, + removed_columns=None, + added_indexes=None, + changed_indexes=None, + removed_indexes=None, + from_table=None, + ): self.name = table_name self.new_name = False self.added_columns = added_columns or OrderedDict() diff --git a/orator/dbal/types/__init__.py b/orator/dbal/types/__init__.py index 633f8661..40a96afc 100644 --- a/orator/dbal/types/__init__.py +++ b/orator/dbal/types/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/orator/events/__init__.py b/orator/events/__init__.py index fe7f5601..c4ad62e5 100644 --- a/orator/events/__init__.py +++ b/orator/events/__init__.py @@ -9,7 +9,7 @@ class Event(object): @classmethod def fire(cls, name, *args, **kwargs): - name = 'orator.%s' % name + name = "orator.%s" % name signal = cls.events.signal(name) for response in signal.send(*args, **kwargs): @@ -18,14 +18,14 @@ def fire(cls, name, *args, **kwargs): @classmethod def listen(cls, name, callback, *args, **kwargs): - name = 'orator.%s' % name + name = "orator.%s" % name signal = cls.events.signal(name) signal.connect(callback, weak=False, *args, **kwargs) @classmethod def forget(cls, name, *args, **kwargs): - name = 'orator.%s' % name + name = "orator.%s" % name signal = cls.events.signal(name) for receiver in signal.receivers: diff --git a/orator/exceptions/connection.py b/orator/exceptions/connection.py index c8581c5d..d7ff7869 100644 --- a/orator/exceptions/connection.py +++ b/orator/exceptions/connection.py @@ -2,7 +2,6 @@ class TransactionError(ConnectionError): - def __init__(self, previous, message=None): self.previous = previous - self.message = 'Transaction Error: ' + self.message = "Transaction Error: " diff --git a/orator/exceptions/connectors.py b/orator/exceptions/connectors.py index 
b3a0a831..5cbdbb05 100644 --- a/orator/exceptions/connectors.py +++ b/orator/exceptions/connectors.py @@ -7,7 +7,6 @@ class ConnectorException(Exception): class UnsupportedDriver(ConnectorException): - def __init__(self, driver): message = 'Driver "%s" is not supported' % driver @@ -15,7 +14,6 @@ def __init__(self, driver): class MissingPackage(ConnectorException): - def __init__(self, driver, supported_packages): if not isinstance(supported_packages, list): supported_packages = [supported_packages] @@ -24,6 +22,8 @@ def __init__(self, driver, supported_packages): if len(supported_packages) == 1: message += '"%s" package' % supported_packages[0] else: - message += 'one of the following packages: "%s"' % ('", "'.join(supported_packages)) - + message += 'one of the following packages: "%s"' % ( + '", "'.join(supported_packages) + ) + super(MissingPackage, self).__init__(message) diff --git a/orator/exceptions/orm.py b/orator/exceptions/orm.py index e52dbf4b..ef7d8a13 100644 --- a/orator/exceptions/orm.py +++ b/orator/exceptions/orm.py @@ -2,11 +2,10 @@ class ModelNotFound(RuntimeError): - def __init__(self, model): self._model = model - self.message = 'No query results found for model [%s]' % self._model.__name__ + self.message = "No query results found for model [%s]" % self._model.__name__ def __str__(self): return self.message @@ -17,7 +16,6 @@ class MassAssignmentError(RuntimeError): class RelatedClassNotFound(RuntimeError): - def __init__(self, related): self._related = related diff --git a/orator/exceptions/query.py b/orator/exceptions/query.py index 1360b541..bafa1408 100644 --- a/orator/exceptions/query.py +++ b/orator/exceptions/query.py @@ -2,7 +2,6 @@ class QueryException(Exception): - def __init__(self, sql, bindings, previous): self.sql = sql self.bindings = bindings @@ -10,7 +9,7 @@ def __init__(self, sql, bindings, previous): self.message = self.format_message(sql, bindings, previous) def format_message(self, sql, bindings, previous): - return '%s 
(SQL: %s (%s))' % (str(previous), sql, bindings) + return "%s (SQL: %s (%s))" % (str(previous), sql, bindings) def __repr__(self): return self.message diff --git a/orator/migrations/database_migration_repository.py b/orator/migrations/database_migration_repository.py index 8a7b1699..fc9e9777 100644 --- a/orator/migrations/database_migration_repository.py +++ b/orator/migrations/database_migration_repository.py @@ -4,7 +4,6 @@ class DatabaseMigrationRepository(object): - def __init__(self, resolver, table): """ :type resolver: orator.database_manager.DatabaseManager @@ -20,7 +19,7 @@ def get_ran(self): :rtype: list """ - return self.table().lists('migration') + return self.table().lists("migration") def get_last(self): """ @@ -28,9 +27,9 @@ def get_last(self): :rtype: list """ - query = self.table().where('batch', self.get_last_batch_number()) + query = self.table().where("batch", self.get_last_batch_number()) - return query.order_by('migration', 'desc').get() + return query.order_by("migration", "desc").get() def log(self, file, batch): """ @@ -39,10 +38,7 @@ def log(self, file, batch): :type file: str :type batch: int """ - record = { - 'migration': file, - 'batch': batch - } + record = {"migration": file, "batch": batch} self.table().insert(**record) @@ -52,7 +48,7 @@ def delete(self, migration): :type migration: dict """ - self.table().where('migration', migration['migration']).delete() + self.table().where("migration", migration["migration"]).delete() def get_next_batch_number(self): """ @@ -68,7 +64,7 @@ def get_last_batch_number(self): :rtype: int """ - return self.table().max('batch') or 0 + return self.table().max("batch") or 0 def create_repository(self): """ @@ -80,8 +76,8 @@ def create_repository(self): # The migrations table is responsible for keeping track of which of the # migrations have actually run for the application. We'll create the # table to hold the migration file's path as well as the batch ID. 
- table.string('migration') - table.integer('batch') + table.string("migration") + table.integer("batch") def repository_exists(self): """ diff --git a/orator/migrations/migration_creator.py b/orator/migrations/migration_creator.py index 2b267acb..b421b46e 100644 --- a/orator/migrations/migration_creator.py +++ b/orator/migrations/migration_creator.py @@ -18,7 +18,6 @@ def mkdir_p(path): class MigrationCreator(object): - def create(self, name, path, table=None, create=False): """ Create a new migration at the given path. @@ -38,14 +37,14 @@ def create(self, name, path, table=None, create=False): if not os.path.exists(os.path.dirname(path)): mkdir_p(os.path.dirname(path)) - parent = os.path.join(os.path.dirname(path), '__init__.py') + parent = os.path.join(os.path.dirname(path), "__init__.py") if not os.path.exists(parent): - with open(parent, 'w'): + with open(parent, "w"): pass stub = self._get_stub(table, create) - with open(path, 'w') as fh: + with open(path, "w") as fh: fh.write(self._populate_stub(name, stub, table)) return path @@ -87,10 +86,10 @@ def _populate_stub(self, name, stub, table): :rtype: str """ - stub = stub.replace('DummyClass', self._get_class_name(name)) + stub = stub.replace("DummyClass", self._get_class_name(name)) if table is not None: - stub = stub.replace('dummy_table', table) + stub = stub.replace("dummy_table", table) return stub @@ -98,7 +97,7 @@ def _get_class_name(self, name): return inflection.camelize(name) def _get_path(self, name, path): - return os.path.join(path, self._get_date_prefix() + '_' + name + '.py') + return os.path.join(path, self._get_date_prefix() + "_" + name + ".py") def _get_date_prefix(self): - return datetime.datetime.utcnow().strftime('%Y_%m_%d_%H%M%S') + return datetime.datetime.utcnow().strftime("%Y_%m_%d_%H%M%S") diff --git a/orator/migrations/migrator.py b/orator/migrations/migrator.py index 432b4ddb..006205e3 100644 --- a/orator/migrations/migrator.py +++ b/orator/migrations/migrator.py @@ -11,7 +11,6 @@ 
class MigratorHandler(logging.NullHandler): - def __init__(self, level=logging.DEBUG): super(MigratorHandler, self).__init__(level) @@ -22,7 +21,6 @@ def handle(self, record): class Migrator(object): - def __init__(self, repository, resolver): """ :type repository: DatabaseMigrationRepository @@ -61,7 +59,7 @@ def run_migration_list(self, path, migrations, pretend=False): :type pretend: bool """ if not migrations: - self._note('Nothing to migrate') + self._note("Nothing to migrate") return @@ -83,7 +81,7 @@ def _run_up(self, path, migration_file, batch, pretend=False): migration = self._resolve(path, migration_file) if pretend: - return self._pretend_to_run(migration, 'up') + return self._pretend_to_run(migration, "up") if migration.transactional: with migration.db.transaction(): @@ -93,7 +91,10 @@ def _run_up(self, path, migration_file, batch, pretend=False): self._repository.log(migration_file, batch) - self._note(decode('[OK] Migrated ') + '%s' % migration_file) + self._note( + decode("[OK] Migrated ") + + "%s" % migration_file + ) def rollback(self, path, pretend=False): """ @@ -112,7 +113,7 @@ def rollback(self, path, pretend=False): migrations = self._repository.get_last() if not migrations: - self._note('Nothing to rollback.') + self._note("Nothing to rollback.") return len(migrations) @@ -140,10 +141,10 @@ def reset(self, path, pretend=False): count = len(migrations) if count == 0: - self._note('Nothing to rollback.') + self._note("Nothing to rollback.") else: for migration in migrations: - self._run_down(path, {'migration': migration}, pretend) + self._run_down(path, {"migration": migration}, pretend) return count @@ -151,12 +152,12 @@ def _run_down(self, path, migration, pretend=False): """ Run "down" a migration instance. 
""" - migration_file = migration['migration'] + migration_file = migration["migration"] instance = self._resolve(path, migration_file) if pretend: - return self._pretend_to_run(instance, 'down') + return self._pretend_to_run(instance, "down") if instance.transactional: with instance.db.transaction(): @@ -166,7 +167,10 @@ def _run_down(self, path, migration, pretend=False): self._repository.delete(migration) - self._note(decode('[OK] Rolled back ') + '%s' % migration_file) + self._note( + decode("[OK] Rolled back ") + + "%s" % migration_file + ) def _get_migration_files(self, path): """ @@ -176,12 +180,12 @@ def _get_migration_files(self, path): :rtype: list """ - files = glob.glob(os.path.join(path, '[0-9]*_*.py')) + files = glob.glob(os.path.join(path, "[0-9]*_*.py")) if not files: return [] - files = list(map(lambda f: os.path.basename(f).replace('.py', ''), files)) + files = list(map(lambda f: os.path.basename(f).replace(".py", ""), files)) files = sorted(files) @@ -197,7 +201,7 @@ def _pretend_to_run(self, migration, method): :param method: The method to execute :type method: str """ - self._note('') + self._note("") names = [] for query in self._get_queries(migration, method): name = migration.__class__.__name__ @@ -206,17 +210,13 @@ def _pretend_to_run(self, migration, method): if isinstance(query, tuple): query, bindings = query - query = highlight( - query, - SqlLexer(), - CommandFormatter() - ).strip() + query = highlight(query, SqlLexer(), CommandFormatter()).strip() if bindings: query = (query, bindings) if name not in names: - self._note('[{}]'.format(name)) + self._note("[{}]".format(name)) names.append(name) self._note(query) @@ -250,19 +250,19 @@ def _resolve(self, path, migration_file): :rtype: orator.migrations.migration.Migration """ - name = '_'.join(migration_file.split('_')[4:]) - migration_file = os.path.join(path, '%s.py' % migration_file) + name = "_".join(migration_file.split("_")[4:]) + migration_file = os.path.join(path, "%s.py" % 
migration_file) # Loading parent module - parent = os.path.join(path, '__init__.py') + parent = os.path.join(path, "__init__.py") if not os.path.exists(parent): - with open(parent, 'w'): + with open(parent, "w"): pass - load_module('migrations', parent) + load_module("migrations", parent) # Loading module - mod = load_module('migrations.%s' % name, migration_file) + mod = load_module("migrations.%s" % name, migration_file) klass = getattr(mod, inflection.camelize(name)) diff --git a/orator/orm/__init__.py b/orator/orm/__init__.py index b9a222c2..fc37731f 100644 --- a/orator/orm/__init__.py +++ b/orator/orm/__init__.py @@ -6,10 +6,18 @@ from .collection import Collection from .factory import Factory from .utils import ( - mutator, accessor, column, - has_one, morph_one, - belongs_to, morph_to, - has_many, has_many_through, morph_many, - belongs_to_many, morph_to_many, morphed_by_many, - scope + mutator, + accessor, + column, + has_one, + morph_one, + belongs_to, + morph_to, + has_many, + has_many_through, + morph_many, + belongs_to_many, + morph_to_many, + morphed_by_many, + scope, ) diff --git a/orator/orm/builder.py b/orator/orm/builder.py index d86d80aa..72a7441e 100644 --- a/orator/orm/builder.py +++ b/orator/orm/builder.py @@ -13,8 +13,19 @@ class Builder(object): _passthru = [ - 'to_sql', 'lists', 'insert', 'insert_get_id', 'pluck', 'count', - 'min', 'max', 'avg', 'sum', 'exists', 'get_bindings', 'raw' + "to_sql", + "lists", + "insert", + "insert_get_id", + "pluck", + "count", + "min", + "max", + "avg", + "sum", + "exists", + "get_bindings", + "raw", ] def __init__(self, query): @@ -97,12 +108,12 @@ def find(self, id, columns=None): :rtype: orator.Model """ if columns is None: - columns = ['*'] + columns = ["*"] if isinstance(id, list): return self.find_many(id, columns) - self._query.where(self._model.get_qualified_key_name(), '=', id) + self._query.where(self._model.get_qualified_key_name(), "=", id) return self.first(columns) @@ -120,7 +131,7 @@ def 
find_many(self, id, columns=None): :rtype: orator.Collection """ if columns is None: - columns = ['*'] + columns = ["*"] if not id: return self._model.new_collection() @@ -165,7 +176,7 @@ def first(self, columns=None): :rtype: mixed """ if columns is None: - columns = ['*'] + columns = ["*"] return self.take(1).get(columns).first() @@ -232,15 +243,19 @@ def chunk(self, count): :return: The current chunk :rtype: list """ - page = 1 - results = self.for_page(page, count).get() + connection = self._model.get_connection_name() + for results in self.apply_scopes().get_query().chunk(count): + models = self._model.hydrate(results, connection) - while not results.is_empty(): - yield results + # If we actually found models we will also eager load any relationships that + # have been specified as needing to be eager loaded, which will solve the + # n+1 query issue for the developers to avoid running a lot of queries. + if len(models) > 0: + models = self.eager_load_relations(models) - page += 1 + collection = self._model.new_collection(models) - results = self.for_page(page, count).get() + yield collection def lists(self, column, key=None): """ @@ -287,7 +302,7 @@ def paginate(self, per_page=None, current_page=None, columns=None): :return: The paginator """ if columns is None: - columns = ['*'] + columns = ["*"] total = self.to_base().get_count_for_pagination() @@ -313,7 +328,7 @@ def simple_paginate(self, per_page=None, current_page=None, columns=None): :return: The paginator """ if columns is None: - columns = ['*'] + columns = ["*"] page = current_page or Paginator.resolve_current_page() per_page = per_page or self._model.get_per_page() @@ -398,7 +413,8 @@ def _add_updated_at_column(self, values): column = self._model.get_updated_at_column() - values.update({column: self._model.fresh_timestamp()}) + if "updated_at" not in values: + values.update({column: self._model.fresh_timestamp_string()}) return values @@ -455,7 +471,7 @@ def eager_load_relations(self, models): 
:rtype: list """ for name, constraints in self._eager_load.items(): - if name.find('.') == -1: + if name.find(".") == -1: models = self._load_relation(models, name, constraints) return models @@ -509,7 +525,7 @@ def _nested_relations(self, relation): for name, constraints in self._eager_load.items(): if self._is_nested(name, relation): - nested[name[len(relation + '.'):]] = constraints + nested[name[len(relation + ".") :]] = constraints return nested @@ -522,11 +538,11 @@ def _is_nested(self, name, relation): :rtype: bool """ - dots = name.find('.') + dots = name.find(".") - return dots and name.startswith(relation + '.') + return dots and name.startswith(relation + ".") - def where(self, column, operator=Null(), value=None, boolean='and'): + def where(self, column, operator=Null(), value=None, boolean="and"): """ Add a where clause to the query @@ -568,9 +584,9 @@ def or_where(self, column, operator=None, value=None): :return: The current Builder instance :rtype: Builder """ - return self.where(column, operator, value, 'or') + return self.where(column, operator, value, "or") - def where_exists(self, query, boolean='and', negate=False): + def where_exists(self, query, boolean="and", negate=False): """ Add an exists clause to the query. @@ -601,9 +617,9 @@ def or_where_exists(self, query, negate=False): :rtype: Builder """ - return self.where_exists(query, 'or', negate) + return self.where_exists(query, "or", negate) - def where_not_exists(self, query, boolean='and'): + def where_not_exists(self, query, boolean="and"): """ Add a where not exists clause to the query. @@ -627,7 +643,7 @@ def or_where_not_exists(self, query): """ return self.or_where_exists(query, True) - def has(self, relation, operator='>=', count=1, boolean='and', extra=None): + def has(self, relation, operator=">=", count=1, boolean="and", extra=None): """ Add a relationship count condition to the query. 
@@ -648,21 +664,25 @@ def has(self, relation, operator='>=', count=1, boolean='and', extra=None): :type: Builder """ - if relation.find('.') >= 0: + if relation.find(".") >= 0: return self._has_nested(relation, operator, count, boolean, extra) relation = self._get_has_relation_query(relation) - query = relation.get_relation_count_query(relation.get_related().new_query(), self) + query = relation.get_relation_count_query( + relation.get_related().new_query(), self + ) # TODO: extra query if extra: if callable(extra): extra(query) - return self._add_has_where(query.apply_scopes(), relation, operator, count, boolean) + return self._add_has_where( + query.apply_scopes(), relation, operator, count, boolean + ) - def _has_nested(self, relations, operator='>=', count=1, boolean='and', extra=None): + def _has_nested(self, relations, operator=">=", count=1, boolean="and", extra=None): """ Add nested relationship count conditions to the query. @@ -683,7 +703,7 @@ def _has_nested(self, relations, operator='>=', count=1, boolean='and', extra=No :rtype: Builder """ - relations = relations.split('.') + relations = relations.split(".") def closure(q): if len(relations) > 1: @@ -693,7 +713,7 @@ def closure(q): return self.where_has(relations.pop(0), closure) - def doesnt_have(self, relation, boolean='and', extra=None): + def doesnt_have(self, relation, boolean="and", extra=None): """ Add a relationship count to the query. @@ -708,9 +728,9 @@ def doesnt_have(self, relation, boolean='and', extra=None): :rtype: Builder """ - return self.has(relation, '<', 1, boolean, extra) + return self.has(relation, "<", 1, boolean, extra) - def where_has(self, relation, extra, operator='>=', count=1): + def where_has(self, relation, extra, operator=">=", count=1): """ Add a relationship count condition to the query with where clauses. 
@@ -728,7 +748,7 @@ def where_has(self, relation, extra, operator='>=', count=1): :rtype: Builder """ - return self.has(relation, operator, count, 'and', extra) + return self.has(relation, operator, count, "and", extra) def where_doesnt_have(self, relation, extra=None): """ @@ -742,9 +762,9 @@ def where_doesnt_have(self, relation, extra=None): :rtype: Builder """ - return self.doesnt_have(relation, 'and', extra) + return self.doesnt_have(relation, "and", extra) - def or_has(self, relation, operator='>=', count=1): + def or_has(self, relation, operator=">=", count=1): """ Add a relationship count condition to the query with an "or". @@ -759,9 +779,9 @@ def or_has(self, relation, operator='>=', count=1): :rtype: Builder """ - return self.has(relation, operator, count, 'or') + return self.has(relation, operator, count, "or") - def or_where_has(self, relation, extra, operator='>=', count=1): + def or_where_has(self, relation, extra, operator=">=", count=1): """ Add a relationship count condition to the query with where clauses and an "or". 
@@ -779,7 +799,7 @@ def or_where_has(self, relation, extra, operator='>=', count=1): :rtype: Builder """ - return self.has(relation, operator, count, 'or', extra) + return self.has(relation, operator, count, "or", extra) def _add_has_where(self, has_query, relation, operator, count, boolean): """ @@ -807,7 +827,9 @@ def _add_has_where(self, has_query, relation, operator, count, boolean): if isinstance(count, basestring) and count.isdigit(): count = QueryExpression(count) - return self.where(QueryExpression('(%s)' % has_query.to_sql()), operator, count, boolean) + return self.where( + QueryExpression("(%s)" % has_query.to_sql()), operator, count, boolean + ) def _merge_model_defined_relation_wheres_to_has_query(self, has_query, relation): """ @@ -821,11 +843,9 @@ def _merge_model_defined_relation_wheres_to_has_query(self, has_query, relation) """ relation_query = relation.get_base_query() - has_query.merge_wheres( - relation_query.wheres, relation_query.get_bindings() - ) + has_query.merge_wheres(relation_query.wheres, relation_query.get_bindings()) - self._query.add_binding(has_query.get_query().get_bindings(), 'where') + self._query.add_binding(has_query.get_query().get_bindings(), "where") def _get_has_relation_query(self, relation): """ @@ -869,8 +889,12 @@ def _parse_with_relations(self, relations): for relation in relations: if isinstance(relation, dict): - name = list(relation.keys())[0] - constraints = relation[name] + for name, constraints in relation.items(): + results = self._parse_nested_with(name, results) + + results[name] = constraints + + continue else: name = relation constraints = self.__class__(self.get_query().new_query()) @@ -894,10 +918,10 @@ def _parse_nested_with(self, name, results): """ progress = [] - for segment in name.split('.'): + for segment in name.split("."): progress.append(segment) - last = '.'.join(progress) + last = ".".join(progress) if last not in results: results[last] = self.__class__(self.get_query().new_query()) @@ -920,7 
+944,9 @@ def _call_scope(self, scope, *args, **kwargs): result = getattr(self._model, scope)(self, *args, **kwargs) if self._should_nest_wheres_for_scope(query, original_where_count): - self._nest_wheres_for_scope(query, [0, original_where_count, len(query.wheres)]) + self._nest_wheres_for_scope( + query, [0, original_where_count, len(query.wheres)] + ) return result or self @@ -1022,13 +1048,9 @@ def _slice_where_conditions(self, wheres, offset, length): :rtype: list """ where_group = self.get_query().for_nested_where() - where_group.wheres = wheres[offset:(offset + length)] + where_group.wheres = wheres[offset : (offset + length)] - return { - 'type': 'nested', - 'query': where_group, - 'boolean': 'and' - } + return {"type": "nested", "query": where_group, "boolean": "and"} def get_query(self): """ @@ -1124,12 +1146,14 @@ def get_macro(self, name): def __dynamic(self, method): from .utils import scope - scope_method = 'scope_%s' % method + scope_method = "scope_%s" % method is_scope = False is_macro = False # New scope definition check - if hasattr(self._model, method) and isinstance(getattr(self._model, method), scope): + if hasattr(self._model, method) and isinstance( + getattr(self._model, method), scope + ): is_scope = True attribute = getattr(self._model, method) scope_method = method @@ -1172,4 +1196,3 @@ def __copy__(self): new.set_model(self._model) return new - diff --git a/orator/orm/collection.py b/orator/orm/collection.py index 2024295a..72ac873b 100644 --- a/orator/orm/collection.py +++ b/orator/orm/collection.py @@ -4,7 +4,6 @@ class Collection(BaseCollection): - def load(self, *relations): """ Load a set of relationships onto the collection. 
diff --git a/orator/orm/factory.py b/orator/orm/factory.py index d02f62e3..560914da 100644 --- a/orator/orm/factory.py +++ b/orator/orm/factory.py @@ -8,7 +8,6 @@ class Factory(object): - def __init__(self, faker=None, resolver=None): """ :param faker: A faker generator instance @@ -56,7 +55,7 @@ def define_as(self, klass, name): """ return self.define(klass, name) - def define(self, klass, name='default'): + def define(self, klass, name="default"): """ Define a class with a given set of attributes. @@ -66,6 +65,7 @@ def define(self, klass, name='default'): :param name: The short name :type name: str """ + def decorate(func): @wraps(func) def wrapped(*args, **kwargs): @@ -77,7 +77,7 @@ def wrapped(*args, **kwargs): return decorate - def register(self, klass, callback, name='default'): + def register(self, klass, callback, name="default"): """ Register a class with a function. @@ -189,7 +189,7 @@ def raw_of(self, klass, name, **attributes): """ return self.raw(klass, _name=name, **attributes) - def raw(self, klass, _name='default', **attributes): + def raw(self, klass, _name="default", **attributes): """ Get the raw attribute dict for a given named model. @@ -210,7 +210,7 @@ def raw(self, klass, _name='default', **attributes): return raw - def of(self, klass, name='default'): + def of(self, klass, name="default"): """ Create a builder for the given model. @@ -222,9 +222,11 @@ def of(self, klass, name='default'): :return: orator.orm.factory_builder.FactoryBuilder """ - return FactoryBuilder(klass, name, self._definitions, self._faker, self._resolver) + return FactoryBuilder( + klass, name, self._definitions, self._faker, self._resolver + ) - def build(self, klass, name='default', amount=None): + def build(self, klass, name="default", amount=None): """ Makes a factory builder with a specified amount. 
@@ -242,7 +244,7 @@ def build(self, klass, name='default', amount=None): if amount is None: if isinstance(name, int): amount = name - name = 'default' + name = "default" else: amount = 1 @@ -287,7 +289,7 @@ def __setitem__(self, key, value): def __contains__(self, item): return item in self._definitions - def __call__(self, klass, name='default', amount=None): + def __call__(self, klass, name="default", amount=None): """ Makes a factory builder with a specified amount. diff --git a/orator/orm/factory_builder.py b/orator/orm/factory_builder.py index 93e285ec..960cd07a 100644 --- a/orator/orm/factory_builder.py +++ b/orator/orm/factory_builder.py @@ -4,7 +4,6 @@ class FactoryBuilder(object): - def __init__(self, klass, name, definitions, faker, resolver=None): """ :param klass: The class diff --git a/orator/orm/mixins/soft_deletes.py b/orator/orm/mixins/soft_deletes.py index 45be01aa..57949e15 100644 --- a/orator/orm/mixins/soft_deletes.py +++ b/orator/orm/mixins/soft_deletes.py @@ -35,7 +35,11 @@ def _do_perform_delete_on_model(self): Perform the actual delete query on this model instance. """ if self.__force_deleting__: - return self.with_trashed().where(self.get_key_name(), self.get_key()).force_delete() + return ( + self.with_trashed() + .where(self.get_key_name(), self.get_key()) + .force_delete() + ) return self._run_soft_delete() @@ -48,15 +52,13 @@ def _run_soft_delete(self): time = self.fresh_timestamp() setattr(self, self.get_deleted_at_column(), time) - query.update({ - self.get_deleted_at_column(): self.from_datetime(time) - }) + query.update({self.get_deleted_at_column(): self.from_datetime(time)}) def restore(self): """ Restore a soft-deleted model instance. 
""" - if self._fire_model_event('restoring') is False: + if self._fire_model_event("restoring") is False: return False setattr(self, self.get_deleted_at_column(), None) @@ -65,7 +67,7 @@ def restore(self): result = self.save() - self._fire_model_event('restored') + self._fire_model_event("restored") return result @@ -99,7 +101,9 @@ def only_trashed(cls): column = instance.get_qualified_deleted_at_column() - return instance.new_query_without_scope(SoftDeletingScope()).where_not_null(column) + return instance.new_query_without_scope(SoftDeletingScope()).where_not_null( + column + ) @classmethod def restoring(cls, callback): @@ -108,7 +112,7 @@ def restoring(cls, callback): :type callback: callable """ - cls._register_model_event('restoring', callback) + cls._register_model_event("restoring", callback) @classmethod def restored(cls, callback): @@ -117,7 +121,7 @@ def restored(cls, callback): :type callback: callable """ - cls._register_model_event('restored', callback) + cls._register_model_event("restored", callback) def get_deleted_at_column(self): """ @@ -125,7 +129,7 @@ def get_deleted_at_column(self): :rtype: str """ - return getattr(self, 'DELETED_AT', 'deleted_at') + return getattr(self, "DELETED_AT", "deleted_at") def get_qualified_deleted_at_column(self): """ @@ -133,4 +137,4 @@ def get_qualified_deleted_at_column(self): :rtype: str """ - return '%s.%s' % (self.get_table(), self.get_deleted_at_column()) + return "%s.%s" % (self.get_table(), self.get_deleted_at_column()) diff --git a/orator/orm/model.py b/orator/orm/model.py index 06b2fedb..f5e62887 100644 --- a/orator/orm/model.py +++ b/orator/orm/model.py @@ -5,6 +5,7 @@ import inflection import inspect import uuid +import datetime from warnings import warn from six import add_metaclass from collections import OrderedDict @@ -14,8 +15,16 @@ from .builder import Builder from .collection import Collection from .relations import ( - Relation, HasOne, HasMany, BelongsTo, BelongsToMany, HasManyThrough, - 
MorphOne, MorphMany, MorphTo, MorphToMany + Relation, + HasOne, + HasMany, + BelongsTo, + BelongsToMany, + HasManyThrough, + MorphOne, + MorphMany, + MorphTo, + MorphToMany, ) from .relations.wrapper import Wrapper, BelongsToManyWrapper from .utils import mutator, accessor @@ -24,7 +33,6 @@ class ModelRegister(dict): - def __init__(self, *args, **kwargs): self.inverse = {} @@ -67,12 +75,12 @@ class Model(object): __table__ = None - __primary_key__ = 'id' + __primary_key__ = "id" __incrementing__ = True __fillable__ = [] - __guarded__ = ['*'] + __guarded__ = ["*"] __unguarded__ = False __hidden__ = [] @@ -109,10 +117,10 @@ class Model(object): __attributes__ = {} - many_methods = ['belongs_to_many', 'morph_to_many', 'morphed_by_many'] + many_methods = ["belongs_to_many", "morph_to_many", "morphed_by_many"] - CREATED_AT = 'created_at' - UPDATED_AT = 'updated_at' + CREATED_AT = "created_at" + UPDATED_AT = "updated_at" def __init__(self, _attributes=None, **attributes): """ @@ -143,11 +151,11 @@ def _boot_if_not_booted(self): if not klass._booted.get(klass): klass._booted[klass] = True - self._fire_model_event('booting') + self._fire_model_event("booting") klass._boot() - self._fire_model_event('booted') + self._fire_model_event("booted") @classmethod def _boot(cls): @@ -168,7 +176,9 @@ def _boot(cls): @classmethod def _boot_columns(cls): connection = cls.resolve_connection() - columns = connection.get_schema_manager().list_table_columns(cls.__table__ or inflection.tableize(cls.__name__)) + columns = connection.get_schema_manager().list_table_columns( + cls.__table__ or inflection.tableize(cls.__name__) + ) cls.__columns__ = list(columns.keys()) @classmethod @@ -177,15 +187,15 @@ def _boot_mixins(cls): Boot the mixins """ for mixin in cls.__bases__: - #if mixin == Model: + # if mixin == Model: # continue - method = 'boot_%s' % inflection.underscore(mixin.__name__) + method = "boot_%s" % inflection.underscore(mixin.__name__) if hasattr(mixin, method): getattr(mixin, 
method)(cls) @classmethod - def add_global_scope(cls, scope, implementation = None): + def add_global_scope(cls, scope, implementation=None): """ Register a new global scope on the model. @@ -205,7 +215,7 @@ def add_global_scope(cls, scope, implementation = None): elif isinstance(scope, Scope): cls._global_scopes[cls][scope.__class__] = scope else: - raise Exception('Global scope must be an instance of Scope or a callable') + raise Exception("Global scope must be an instance of Scope or a callable") @classmethod def has_global_scope(cls, scope): @@ -565,7 +575,7 @@ def find(cls, id, columns=None): return instance.new_collection() if columns is None: - columns = ['*'] + columns = ["*"] return instance.new_query().find(id, columns) @@ -639,9 +649,9 @@ def with_(cls, *relations): return instance.new_query().with_(*relations) - def has_one(self, related, - foreign_key=None, local_key=None, relation=None, - _wrapped=True): + def has_one( + self, related, foreign_key=None, local_key=None, relation=None, _wrapped=True + ): """ Define a one to one relationship. @@ -675,15 +685,19 @@ def has_one(self, related, if not local_key: local_key = self.get_key_name() - rel = HasOne(instance.new_query(), - self, - '%s.%s' % (instance.get_table(), foreign_key), - local_key) + rel = HasOne( + instance.new_query(), + self, + "%s.%s" % (instance.get_table(), foreign_key), + local_key, + ) if _wrapped: - warn('Using has_one method directly is deprecated. ' - 'Use the appropriate decorator instead.', - category=DeprecationWarning) + warn( + "Using has_one method directly is deprecated. 
" + "Use the appropriate decorator instead.", + category=DeprecationWarning, + ) rel = Wrapper(rel) @@ -691,10 +705,16 @@ def has_one(self, related, return rel - def morph_one(self, related, name, - type_column=None, id_column=None, - local_key=None, relation=None, - _wrapped=True): + def morph_one( + self, + related, + name, + type_column=None, + id_column=None, + local_key=None, + relation=None, + _wrapped=True, + ): """ Define a polymorphic one to one relationship. @@ -730,14 +750,20 @@ def morph_one(self, related, name, if not local_key: local_key = self.get_key_name() - rel = MorphOne(instance.new_query(), self, - '%s.%s' % (table, type_column), - '%s.%s' % (table, id_column), local_key) + rel = MorphOne( + instance.new_query(), + self, + "%s.%s" % (table, type_column), + "%s.%s" % (table, id_column), + local_key, + ) if _wrapped: - warn('Using morph_one method directly is deprecated. ' - 'Use the appropriate decorator instead.', - category=DeprecationWarning) + warn( + "Using morph_one method directly is deprecated. " + "Use the appropriate decorator instead.", + category=DeprecationWarning, + ) rel = Wrapper(rel) @@ -745,9 +771,9 @@ def morph_one(self, related, name, return rel - def belongs_to(self, related, - foreign_key=None, other_key=None, relation=None, - _wrapped=True): + def belongs_to( + self, related, foreign_key=None, other_key=None, relation=None, _wrapped=True + ): """ Define an inverse one to one or many relationship. @@ -771,7 +797,7 @@ def belongs_to(self, related, return self._relations[relation] if foreign_key is None: - foreign_key = '%s_id' % inflection.underscore(relation) + foreign_key = "%s_id" % inflection.underscore(relation) instance = self._get_related(related, True) @@ -783,9 +809,11 @@ def belongs_to(self, related, rel = BelongsTo(query, self, foreign_key, other_key, relation) if _wrapped: - warn('Using belongs_to method directly is deprecated. 
' - 'Use the appropriate decorator instead.', - category=DeprecationWarning) + warn( + "Using belongs_to method directly is deprecated. " + "Use the appropriate decorator instead.", + category=DeprecationWarning, + ) rel = Wrapper(rel) @@ -793,8 +821,7 @@ def belongs_to(self, related, return rel - def morph_to(self, name=None, type_column=None, id_column=None, - _wrapped=True): + def morph_to(self, name=None, type_column=None, id_column=None, _wrapped=True): """ Define a polymorphic, inverse one-to-one or many relationship. @@ -831,14 +858,21 @@ def morph_to(self, name=None, type_column=None, id_column=None, instance = klass() instance.set_connection(self.get_connection_name()) - rel = MorphTo(instance.new_query(), - self, id_column, - instance.get_key_name(), type_column, name) + rel = MorphTo( + instance.new_query(), + self, + id_column, + instance.get_key_name(), + type_column, + name, + ) if _wrapped: - warn('Using morph_to method directly is deprecated. ' - 'Use the appropriate decorator instead.', - category=DeprecationWarning) + warn( + "Using morph_to method directly is deprecated. " + "Use the appropriate decorator instead.", + category=DeprecationWarning, + ) rel = Wrapper(rel) @@ -858,8 +892,9 @@ def get_actual_class_for_morph(self, slug): if morph_name == slug: return cls - def has_many(self, related, foreign_key=None, local_key=None, - relation=None, _wrapped=True): + def has_many( + self, related, foreign_key=None, local_key=None, relation=None, _wrapped=True + ): """ Define a one to many relationship. @@ -893,15 +928,19 @@ def has_many(self, related, foreign_key=None, local_key=None, if not local_key: local_key = self.get_key_name() - rel = HasMany(instance.new_query(), - self, - '%s.%s' % (instance.get_table(), foreign_key), - local_key) + rel = HasMany( + instance.new_query(), + self, + "%s.%s" % (instance.get_table(), foreign_key), + local_key, + ) if _wrapped: - warn('Using has_many method directly is deprecated. 
' - 'Use the appropriate decorator instead.', - category=DeprecationWarning) + warn( + "Using has_many method directly is deprecated. " + "Use the appropriate decorator instead.", + category=DeprecationWarning, + ) rel = Wrapper(rel) @@ -909,9 +948,15 @@ def has_many(self, related, foreign_key=None, local_key=None, return rel - def has_many_through(self, related, through, - first_key=None, second_key=None, relation=None, - _wrapped=True): + def has_many_through( + self, + related, + through, + first_key=None, + second_key=None, + relation=None, + _wrapped=True, + ): """ Define a has-many-through relationship. @@ -948,13 +993,20 @@ def has_many_through(self, related, through, if not second_key: second_key = through.get_foreign_key() - rel = HasManyThrough(self._get_related(related)().new_query(), - self, through, first_key, second_key) + rel = HasManyThrough( + self._get_related(related)().new_query(), + self, + through, + first_key, + second_key, + ) if _wrapped: - warn('Using has_many_through method directly is deprecated. ' - 'Use the appropriate decorator instead.', - category=DeprecationWarning) + warn( + "Using has_many_through method directly is deprecated. " + "Use the appropriate decorator instead.", + category=DeprecationWarning, + ) rel = Wrapper(rel) @@ -962,10 +1014,16 @@ def has_many_through(self, related, through, return rel - def morph_many(self, related, name, - type_column=None, id_column=None, - local_key=None, relation=None, - _wrapped=True): + def morph_many( + self, + related, + name, + type_column=None, + id_column=None, + local_key=None, + relation=None, + _wrapped=True, + ): """ Define a polymorphic one to many relationship. 
@@ -1001,14 +1059,20 @@ def morph_many(self, related, name, if not local_key: local_key = self.get_key_name() - rel = MorphMany(instance.new_query(), self, - '%s.%s' % (table, type_column), - '%s.%s' % (table, id_column), local_key) + rel = MorphMany( + instance.new_query(), + self, + "%s.%s" % (table, type_column), + "%s.%s" % (table, id_column), + local_key, + ) if _wrapped: - warn('Using morph_many method directly is deprecated. ' - 'Use the appropriate decorator instead.', - category=DeprecationWarning) + warn( + "Using morph_many method directly is deprecated. " + "Use the appropriate decorator instead.", + category=DeprecationWarning, + ) rel = Wrapper(rel) @@ -1016,9 +1080,15 @@ def morph_many(self, related, name, return rel - def belongs_to_many(self, related, table=None, - foreign_key=None, other_key=None, - relation=None, _wrapped=True): + def belongs_to_many( + self, + related, + table=None, + foreign_key=None, + other_key=None, + relation=None, + _wrapped=True, + ): """ Define a many-to-many relationship. @@ -1060,9 +1130,11 @@ def belongs_to_many(self, related, table=None, rel = BelongsToMany(query, self, table, foreign_key, other_key, relation) if _wrapped: - warn('Using belongs_to_many method directly is deprecated. ' - 'Use the appropriate decorator instead.', - category=DeprecationWarning) + warn( + "Using belongs_to_many method directly is deprecated. " + "Use the appropriate decorator instead.", + category=DeprecationWarning, + ) rel = BelongsToManyWrapper(rel) @@ -1070,10 +1142,17 @@ def belongs_to_many(self, related, table=None, return rel - def morph_to_many(self, related, name, table=None, - foreign_key=None, other_key=None, - inverse=False, relation=None, - _wrapped=True): + def morph_to_many( + self, + related, + name, + table=None, + foreign_key=None, + other_key=None, + inverse=False, + relation=None, + _wrapped=True, + ): """ Define a polymorphic many-to-many relationship. 
@@ -1106,7 +1185,7 @@ def morph_to_many(self, related, name, table=None, return self._relations[caller] if not foreign_key: - foreign_key = name + '_id' + foreign_key = name + "_id" instance = self._get_related(related, True) @@ -1118,13 +1197,16 @@ def morph_to_many(self, related, name, table=None, if not table: table = inflection.pluralize(name) - rel = MorphToMany(query, self, name, table, - foreign_key, other_key, caller, inverse) + rel = MorphToMany( + query, self, name, table, foreign_key, other_key, caller, inverse + ) if _wrapped: - warn('Using morph_to_many method directly is deprecated. ' - 'Use the appropriate decorator instead.', - category=DeprecationWarning) + warn( + "Using morph_to_many method directly is deprecated. " + "Use the appropriate decorator instead.", + category=DeprecationWarning, + ) rel = Wrapper(rel) @@ -1132,7 +1214,16 @@ def morph_to_many(self, related, name, table=None, return rel - def morphed_by_many(self, related, name, table=None, foreign_key=None, other_key=None, relation=None): + def morphed_by_many( + self, + related, + name, + table=None, + foreign_key=None, + other_key=None, + relation=None, + _wrapped=False, + ): """ Define a polymorphic many-to-many relationship. 
@@ -1160,9 +1251,11 @@ def morphed_by_many(self, related, name, table=None, foreign_key=None, other_key foreign_key = self.get_foreign_key() if not other_key: - other_key = name + '_id' + other_key = name + "_id" - return self.morph_to_many(related, name, table, foreign_key, other_key, True, relation) + return self.morph_to_many( + related, name, table, foreign_key, other_key, True, relation, _wrapped + ) def _get_related(self, related, as_instance=False): """ @@ -1210,7 +1303,7 @@ def joining_table(self, related): models = sorted([related, base]) - return '_'.join(models) + return "_".join(models) @classmethod def destroy(cls, *ids): @@ -1249,10 +1342,10 @@ def delete(self): :raises: Exception """ if self.__primary_key__ is None: - raise Exception('No primary key defined on the model.') + raise Exception("No primary key defined on the model.") if self._exists: - if self._fire_model_event('deleting') is False: + if self._fire_model_event("deleting") is False: return False self.touch_owners() @@ -1261,7 +1354,7 @@ def delete(self): self._exists = False - self._fire_model_event('deleted') + self._fire_model_event("deleted") return True @@ -1275,7 +1368,7 @@ def _perform_delete_on_model(self): """ Perform the actual delete query on this model instance. 
""" - if hasattr(self, '_do_perform_delete_on_model'): + if hasattr(self, "_do_perform_delete_on_model"): return self._do_perform_delete_on_model() return self.new_query().where(self.get_key_name(), self.get_key()).delete() @@ -1287,7 +1380,7 @@ def saving(cls, callback): :type callback: callable """ - cls._register_model_event('saving', callback) + cls._register_model_event("saving", callback) @classmethod def saved(cls, callback): @@ -1296,7 +1389,7 @@ def saved(cls, callback): :type callback: callable """ - cls._register_model_event('saved', callback) + cls._register_model_event("saved", callback) @classmethod def updating(cls, callback): @@ -1305,7 +1398,7 @@ def updating(cls, callback): :type callback: callable """ - cls._register_model_event('updating', callback) + cls._register_model_event("updating", callback) @classmethod def updated(cls, callback): @@ -1314,7 +1407,7 @@ def updated(cls, callback): :type callback: callable """ - cls._register_model_event('updated', callback) + cls._register_model_event("updated", callback) @classmethod def creating(cls, callback): @@ -1323,7 +1416,7 @@ def creating(cls, callback): :type callback: callable """ - cls._register_model_event('creating', callback) + cls._register_model_event("creating", callback) @classmethod def created(cls, callback): @@ -1332,7 +1425,7 @@ def created(cls, callback): :type callback: callable """ - cls._register_model_event('created', callback) + cls._register_model_event("created", callback) @classmethod def deleting(cls, callback): @@ -1341,7 +1434,7 @@ def deleting(cls, callback): :type callback: callable """ - cls._register_model_event('deleting', callback) + cls._register_model_event("deleting", callback) @classmethod def deleted(cls, callback): @@ -1350,7 +1443,7 @@ def deleted(cls, callback): :type callback: callable """ - cls._register_model_event('deleted', callback) + cls._register_model_event("deleted", callback) @classmethod def flush_event_listeners(cls): @@ -1361,7 +1454,7 @@ def 
flush_event_listeners(cls): return for event in cls.get_observable_events(): - cls.__dispatcher__.forget('%s: %s' % (event, cls.__name__)) + cls.__dispatcher__.forget("%s: %s" % (event, cls.__name__)) @classmethod def _register_model_event(cls, event, callback): @@ -1375,7 +1468,7 @@ def _register_model_event(cls, event, callback): :type callback: callable """ if cls.__dispatcher__: - cls.__dispatcher__.listen('%s: %s' % (event, cls.__name__), callback) + cls.__dispatcher__.listen("%s: %s" % (event, cls.__name__), callback) @classmethod def get_observable_events(cls): @@ -1385,9 +1478,16 @@ def get_observable_events(cls): :rtype: list """ default_events = [ - 'creating', 'created', 'updating', 'updated', - 'deleting', 'deleted', 'saving', 'saved', - 'restoring', 'restored' + "creating", + "created", + "updating", + "updated", + "deleting", + "deleted", + "saving", + "saved", + "restoring", + "restored", ] return default_events + cls.__observables__ @@ -1405,7 +1505,7 @@ def _increment(self, column, amount=1): :return: The new column value :rtype: int """ - return self._increment_or_decrement(column, amount, 'increment') + return self._increment_or_decrement(column, amount, "increment") def _decrement(self, column, amount=1): """ @@ -1420,7 +1520,7 @@ def _decrement(self, column, amount=1): :return: The new column value :rtype: int """ - return self._increment_or_decrement(column, amount, 'decrement') + return self._increment_or_decrement(column, amount, "decrement") def _increment_or_decrement(self, column, amount, method): """ @@ -1464,7 +1564,11 @@ def _increment_or_decrement_attribute_value(self, column, amount, method): :return: None """ - setattr(self, column, getattr(self, column) + (amount if method == 'increment' else amount * -1)) + setattr( + self, + column, + getattr(self, column) + (amount if method == "increment" else amount * -1), + ) self.sync_original_attribute(column) @@ -1517,7 +1621,7 @@ def save(self, options=None): query = self.new_query() - if 
self._fire_model_event('saving') is False: + if self._fire_model_event("saving") is False: return False if self._exists: @@ -1534,11 +1638,11 @@ def _finish_save(self, options): """ Finish processing on a successful save operation. """ - self._fire_model_event('saved') + self._fire_model_event("saved") self.sync_original() - if options.get('touch', True): + if options.get("touch", True): self.touch_owners() def _perform_update(self, query, options=None): @@ -1557,10 +1661,10 @@ def _perform_update(self, query, options=None): dirty = self.get_dirty() if len(dirty): - if self._fire_model_event('updating') is False: + if self._fire_model_event("updating") is False: return False - if self.__timestamps__ and options.get('timestamps', True): + if self.__timestamps__ and options.get("timestamps", True): self._update_timestamps() dirty = self.get_dirty() @@ -1568,7 +1672,7 @@ def _perform_update(self, query, options=None): if len(dirty): self._set_keys_for_save_query(query).update(dirty) - self._fire_model_event('updated') + self._fire_model_event("updated") return True @@ -1585,10 +1689,10 @@ def _perform_insert(self, query, options=None): if options is None: options = {} - if self._fire_model_event('creating') is False: + if self._fire_model_event("creating") is False: return False - if self.__timestamps__ and options.get('timestamps', True): + if self.__timestamps__ and options.get("timestamps", True): self._update_timestamps() attributes = self._attributes @@ -1600,7 +1704,7 @@ def _perform_insert(self, query, options=None): self._exists = True - self._fire_model_event('created') + self._fire_model_event("created") return True @@ -1627,9 +1731,9 @@ def touch_owners(self): for relation in self.__touches__: if hasattr(self, relation): _relation = getattr(self, relation) - _relation().touch() - if _relation is not None: + if _relation: + _relation.touch() _relation.touch_owners() def touches(self, relation): @@ -1655,7 +1759,7 @@ def _fire_model_event(self, event): # We 
will append the names of the class to the event to distinguish it from # other model events that are fired, allowing us to listen on each model # event set individually instead of catching event for all the models. - event = '%s: %s' % (event, self.__class__.__name__) + event = "%s: %s" % (event, self.__class__.__name__) return self.__dispatcher__.fire(event, self) @@ -1701,10 +1805,16 @@ def _update_timestamps(self): """ time = self.fresh_timestamp() - if not self.is_dirty(self.UPDATED_AT) and self._should_set_timestamp(self.UPDATED_AT): + if not self.is_dirty(self.UPDATED_AT) and self._should_set_timestamp( + self.UPDATED_AT + ): self.set_updated_at(time) - if not self._exists and not self.is_dirty(self.CREATED_AT) and self._should_set_timestamp(self.CREATED_AT): + if ( + not self._exists + and not self.is_dirty(self.CREATED_AT) + and self._should_set_timestamp(self.CREATED_AT) + ): self.set_created_at(time) def _should_set_timestamp(self, timestamp): @@ -1763,6 +1873,14 @@ def fresh_timestamp(self): """ return pendulum.utcnow() + def fresh_timestamp_string(self): + """ + Get a fresh timestamp string for the model. + + :return: str + """ + return self.from_datetime(self.fresh_timestamp()) + def new_query(self): """ Get a new query builder for the model's table @@ -1795,9 +1913,7 @@ def new_query_without_scopes(self): :return: A Builder instance :rtype: Builder """ - builder = self.new_orm_builder( - self._new_base_query_builder() - ) + builder = self.new_orm_builder(self._new_base_query_builder()) return builder.set_model(self).with_(*self._with) @@ -1919,7 +2035,7 @@ def get_qualified_key_name(self): :rtype: str """ - return '%s.%s' % (self.get_table(), self.get_key_name()) + return "%s.%s" % (self.get_table(), self.get_key_name()) def uses_timestamps(self): """ @@ -1934,10 +2050,10 @@ def get_morphs(self, name, type, id): Get the polymorphic relationship columns. 
""" if not type: - type = name + '_type' + type = name + "_type" if not id: - id = name + '_id' + id = name + "_id" return type, id @@ -1965,7 +2081,7 @@ def get_foreign_key(self): :rtype: str """ - return '%s_id' % inflection.singularize(self.get_table()) + return "%s_id" % inflection.singularize(self.get_table()) def get_hidden(self): """ @@ -2095,7 +2211,7 @@ def is_fillable(self, key): if self.is_guarded(key): return False - return not self.__fillable__ and not key.startswith('_') + return not self.__fillable__ and not key.startswith("_") def is_guarded(self, key): """ @@ -2107,7 +2223,7 @@ def is_guarded(self, key): :return: Whether the attribute is guarded or not :rtype: bool """ - return key in self.__guarded__ or self.__guarded__ == ['*'] + return key in self.__guarded__ or self.__guarded__ == ["*"] def totally_guarded(self): """ @@ -2115,7 +2231,7 @@ def totally_guarded(self): :rtype: bool """ - return len(self.__fillable__) == 0 and self.__guarded__ == ['*'] + return len(self.__fillable__) == 0 and self.__guarded__ == ["*"] def _remove_table_from_key(self, key): """ @@ -2126,10 +2242,10 @@ def _remove_table_from_key(self, key): :rtype: str """ - if '.' not in key: + if "." 
not in key: return key - return key.split('.')[-1] + return key.split(".")[-1] def get_incrementing(self): return self.__incrementing__ @@ -2242,9 +2358,9 @@ def relations_to_dict(self): continue relation = None - if hasattr(value, 'serialize'): + if hasattr(value, "serialize"): relation = value.serialize() - elif hasattr(value, 'to_dict'): + elif hasattr(value, "to_dict"): relation = value.to_dict() elif value is None: relation = value @@ -2272,7 +2388,11 @@ def _get_dictable_items(self, values): if len(self.__visible__) > 0: return {x: values[x] for x in values.keys() if x in self.__visible__} - return {x: values[x] for x in values.keys() if x not in self.__hidden__ and not x.startswith('_')} + return { + x: values[x] + for x in values.keys() + if x not in self.__hidden__ and not x.startswith("_") + } def get_attribute(self, key, original=None): """ @@ -2337,7 +2457,9 @@ def _get_relationship_from_method(self, method, relations=None): relations = relations or super(Model, self).__getattribute__(method) if not isinstance(relations, Relation): - raise RuntimeError('Relationship method must return an object of type Relation') + raise RuntimeError( + "Relationship method must return an object of type Relation" + ) self._relations[method] = relations @@ -2352,7 +2474,7 @@ def has_get_mutator(self, key): :rtype: bool """ - return hasattr(self, 'get_%s_attribute' % inflection.underscore(key)) + return hasattr(self, "get_%s_attribute" % inflection.underscore(key)) def _mutate_attribute_for_dict(self, key): """ @@ -2363,7 +2485,7 @@ def _mutate_attribute_for_dict(self, key): """ value = getattr(self, key) - if hasattr(value, 'to_dict'): + if hasattr(value, "to_dict"): return value.to_dict() if key in self.get_dates(): @@ -2409,7 +2531,7 @@ def _is_json_castable(self, key): if self._has_cast(key): type = self._get_cast_type(key) - return type in ['list', 'dict', 'json', 'object'] + return type in ["list", "dict", "json", "object"] return False @@ -2440,15 +2562,15 @@ def 
_cast_attribute(self, key, value): return None type = self._get_cast_type(key) - if type in ['int', 'integer']: + if type in ["int", "integer"]: return int(value) - elif type in ['real', 'float', 'double']: + elif type in ["real", "float", "double"]: return float(value) - elif type in ['string', 'str']: + elif type in ["string", "str"]: return str(value) - elif type in ['bool', 'boolean']: + elif type in ["bool", "boolean"]: return bool(value) - elif type in ['dict', 'list', 'json'] and isinstance(value, basestring): + elif type in ["dict", "list", "json"] and isinstance(value, basestring): return json.loads(value) else: return value @@ -2465,20 +2587,32 @@ def get_dates(self): def from_datetime(self, value): """ - Convert datetime to a datetime object + Convert datetime to a storable string. + + :param value: The datetime value + :type value: pendulum.Pendulum or datetime.date or datetime.datetime - :rtype: datetime.datetime + :rtype: str """ + date_format = self.get_connection().get_query_grammar().get_date_format() + if isinstance(value, pendulum.Pendulum): - return value + return value.format(date_format) - return pendulum.instance(value) + if isinstance(value, datetime.date) and not isinstance( + value, (datetime.datetime) + ): + value = pendulum.date.instance(value) + + return value.format(date_format) + + return pendulum.instance(value).format(date_format) def as_datetime(self, value): """ Return a timestamp as a datetime. 
- :rtype: pendulum.Pendulum + :rtype: pendulum.Pendulum or pendulum.Date """ if isinstance(value, basestring): return pendulum.parse(value) @@ -2486,6 +2620,11 @@ def as_datetime(self, value): if isinstance(value, (int, float)): return pendulum.from_timestamp(value) + if isinstance(value, datetime.date) and not isinstance( + value, (datetime.datetime) + ): + return pendulum.date.instance(value) + return pendulum.instance(value) def get_date_format(self): @@ -2494,7 +2633,7 @@ def get_date_format(self): :rtype: str """ - return 'iso' + return "iso" def _format_date(self, date): """ @@ -2510,7 +2649,7 @@ def _format_date(self, date): format = self.get_date_format() - if format == 'iso': + if format == "iso": if isinstance(date, basestring): return pendulum.parse(date).isoformat() @@ -2549,10 +2688,12 @@ def replicate(self, except_=None): except_ = [ self.get_key_name(), self.get_created_at_column(), - self.get_updated_at_column() + self.get_updated_at_column(), ] - attributes = {x: self._attributes[x] for x in self._attributes if x not in except_} + attributes = { + x: self._attributes[x] for x in self._attributes if x not in except_ + } instance = self.new_instance(attributes) @@ -2816,7 +2957,12 @@ def __getattr__(self, item): return self.get_attribute(item) def __setattr__(self, key, value): - if key in ['_attributes', '_exists', '_relations', '_original'] or key.startswith('__'): + if key in [ + "_attributes", + "_exists", + "_relations", + "_original", + ] or key.startswith("__"): return object.__setattr__(self, key, value) if self._has_set_mutator(key): @@ -2841,14 +2987,14 @@ def __delattr__(self, item): def __getstate__(self): return { - 'attributes': self._attributes, - 'relations': self._relations, - 'exists': self._exists + "attributes": self._attributes, + "relations": self._relations, + "exists": self._exists, } def __setstate__(self, state): self._boot_if_not_booted() - self.set_raw_attributes(state['attributes'], True) - 
self.set_relations(state['relations']) - self.set_exists(state['exists']) + self.set_raw_attributes(state["attributes"], True) + self.set_relations(state["relations"]) + self.set_exists(state["exists"]) diff --git a/orator/orm/relations/belongs_to.py b/orator/orm/relations/belongs_to.py index 12b7d8de..062e0c5f 100644 --- a/orator/orm/relations/belongs_to.py +++ b/orator/orm/relations/belongs_to.py @@ -6,7 +6,6 @@ class BelongsTo(Relation): - def __init__(self, query, parent, foreign_key, other_key, relation): """ :param query: A Builder instance @@ -34,6 +33,9 @@ def get_results(self): """ Get the results of the relationship. """ + if self._query is None: + return None + return self._query.first() def add_constraints(self): @@ -43,9 +45,15 @@ def add_constraints(self): :rtype: None """ if self._constraints: - table = self._related.get_table() + foreign_key = getattr(self._parent, self._foreign_key, None) + if foreign_key is None: + self._query = None + else: + table = self._related.get_table() - self._query.where('%s.%s' % (table, self._other_key), '=', getattr(self._parent, self._foreign_key)) + self._query.where( + "{}.{}".format(table, self._other_key), "=", foreign_key + ) def get_relation_count_query(self, query, parent): """ @@ -56,11 +64,15 @@ def get_relation_count_query(self, query, parent): :rtype: Builder """ - query.select(QueryExpression('COUNT(*)')) + query.select(QueryExpression("COUNT(*)")) - other_key = self.wrap('%s.%s' % (query.get_model().get_table(), self._other_key)) + other_key = self.wrap( + "%s.%s" % (query.get_model().get_table(), self._other_key) + ) - return query.where(self.get_qualified_foreign_key(), '=', QueryExpression(other_key)) + return query.where( + self.get_qualified_foreign_key(), "=", QueryExpression(other_key) + ) def add_eager_constraints(self, models): """ @@ -68,7 +80,7 @@ def add_eager_constraints(self, models): :type models: list """ - key = '%s.%s' % (self._related.get_table(), self._other_key) + key = "%s.%s" % 
(self._related.get_table(), self._other_key) self._query.where_in(key, self._get_eager_model_keys(models)) @@ -142,9 +154,13 @@ def associate(self, model): :rtype: orator.Model """ - self._parent.set_attribute(self._foreign_key, model.get_attribute(self._other_key)) + self._parent.set_attribute( + self._foreign_key, model.get_attribute(self._other_key) + ) - return self._parent.set_relation(self._relation, Result(model, self, self._parent)) + return self._parent.set_relation( + self._relation, Result(model, self, self._parent) + ) def dissociate(self): """ @@ -154,7 +170,9 @@ def dissociate(self): """ self._parent.set_attribute(self._foreign_key, None) - return self._parent.set_relation(self._relation, Result(None, self, self._parent)) + return self._parent.set_relation( + self._relation, Result(None, self, self._parent) + ) def update(self, _attributes=None, **attributes): """ @@ -176,19 +194,15 @@ def get_foreign_key(self): return self._foreign_key def get_qualified_foreign_key(self): - return '%s.%s' % (self._parent.get_table(), self._foreign_key) + return "%s.%s" % (self._parent.get_table(), self._foreign_key) def get_other_key(self): return self._other_key def get_qualified_other_key_name(self): - return '%s.%s' % (self._related.get_table(), self._other_key) + return "%s.%s" % (self._related.get_table(), self._other_key) def _new_instance(self, model): return BelongsTo( - self.new_query(), - model, - self._foreign_key, - self._other_key, - self._relation + self.new_query(), model, self._foreign_key, self._other_key, self._relation ) diff --git a/orator/orm/relations/belongs_to_many.py b/orator/orm/relations/belongs_to_many.py index b2273587..8918769b 100644 --- a/orator/orm/relations/belongs_to_many.py +++ b/orator/orm/relations/belongs_to_many.py @@ -21,7 +21,9 @@ class BelongsToMany(Relation): _pivot_columns = [] _pivot_wheres = [] - def __init__(self, query, parent, table, foreign_key, other_key, relation_name=None): + def __init__( + self, query, parent, 
table, foreign_key, other_key, relation_name=None + ): """ :param query: A Builder instance :type query: Builder @@ -57,7 +59,7 @@ def get_results(self): """ return self.get() - def where_pivot(self, column, operator=None, value=None, boolean='and'): + def where_pivot(self, column, operator=None, value=None, boolean="and"): """ Set a where clause for a pivot table column. @@ -78,7 +80,9 @@ def where_pivot(self, column, operator=None, value=None, boolean='and'): """ self._pivot_wheres.append([column, operator, value, boolean]) - return self._query.where('%s.%s' % (self._table, column), operator, value, boolean) + return self._query.where( + "%s.%s" % (self._table, column), operator, value, boolean + ) def or_where_pivot(self, column, operator=None, value=None): """ @@ -96,7 +100,7 @@ def or_where_pivot(self, column, operator=None, value=None): :return: self :rtype: BelongsToMany """ - return self.where_pivot(column, operator, value, 'or') + return self.where_pivot(column, operator, value, "or") def first(self, columns=None): """ @@ -136,7 +140,7 @@ def get(self, columns=None): :rtype: orator.Collection """ if columns is None: - columns = ['*'] + columns = ["*"] if self._query.get_query().columns: columns = [] @@ -161,7 +165,7 @@ def _hydrate_pivot_relation(self, models): for model in models: pivot = self.new_existing_pivot(self._clean_pivot_attributes(model)) - model.set_relation('pivot', pivot) + model.set_relation("pivot", pivot) def _clean_pivot_attributes(self, model): """ @@ -173,7 +177,7 @@ def _clean_pivot_attributes(self, model): delete_keys = [] for key, value in model.get_attributes().items(): - if key.find('pivot_') == 0: + if key.find("pivot_") == 0: values[key[6:]] = value delete_keys.append(key) @@ -219,16 +223,18 @@ def get_relation_count_query_for_self_join(self, query, parent): :rtype: orator.orm.Builder """ - query.select(QueryExpression('COUNT(*)')) + query.select(QueryExpression("COUNT(*)")) table_prefix = 
self._query.get_query().get_connection().get_table_prefix() hash_ = self.get_relation_count_hash() - query.from_('%s AS %s%s' % (self._table, table_prefix, hash_)) + query.from_("%s AS %s%s" % (self._table, table_prefix, hash_)) key = self.wrap(self.get_qualified_parent_key_name()) - return query.where('%s.%s' % (hash_, self._foreign_key), '=', QueryExpression(key)) + return query.where( + "%s.%s" % (hash_, self._foreign_key), "=", QueryExpression(key) + ) def get_relation_count_hash(self): """ @@ -236,7 +242,7 @@ def get_relation_count_hash(self): :rtype: str """ - return 'self_%s' % (hashlib.md5(str(time.time()).encode()).hexdigest()) + return "self_%s" % (hashlib.md5(str(time.time()).encode()).hexdigest()) def _get_select_columns(self, columns=None): """ @@ -247,8 +253,8 @@ def _get_select_columns(self, columns=None): :rtype: list """ - if columns == ['*'] or columns is None: - columns = ['%s.*' % self._related.get_table()] + if columns == ["*"] or columns is None: + columns = ["%s.*" % self._related.get_table()] return columns + self._get_aliased_pivot_columns() @@ -263,9 +269,9 @@ def _get_aliased_pivot_columns(self): columns = [] for column in defaults + self._pivot_columns: - value = '%s.%s AS pivot_%s' % (self._table, column, column) + value = "%s.%s AS pivot_%s" % (self._table, column, column) if value not in columns: - columns.append('%s.%s AS pivot_%s' % (self._table, column, column)) + columns.append("%s.%s AS pivot_%s" % (self._table, column, column)) return columns @@ -295,9 +301,9 @@ def _set_join(self, query=None): base_table = self._related.get_table() - key = '%s.%s' % (base_table, self._related.get_key_name()) + key = "%s.%s" % (base_table, self._related.get_key_name()) - query.join(self._table, key, '=', self.get_other_key()) + query.join(self._table, key, "=", self.get_other_key()) return self @@ -310,7 +316,7 @@ def _set_where(self): """ foreign = self.get_foreign_key() - self._query.where(foreign, '=', self._parent.get_key()) + 
self._query.where(foreign, "=", self._parent.get_key()) return self @@ -330,7 +336,9 @@ def init_relation(self, models, relation): :type relation: str """ for model in models: - model.set_relation(relation, Result(self._related.new_collection(), self, model)) + model.set_relation( + relation, Result(self._related.new_collection(), self, model) + ) return models @@ -348,7 +356,9 @@ def match(self, models, results, relation): key = model.get_key() if key in dictionary: - collection = Result(self._related.new_collection(dictionary[key]), self, model) + collection = Result( + self._related.new_collection(dictionary[key]), self, model + ) else: collection = Result(self._related.new_collection(), self, model) @@ -416,7 +426,7 @@ def save(self, model, joining=None, touch=True): if joining is None: joining = {} - model.save({'touch': False}) + model.save({"touch": False}) self.attach(model.get_key(), joining, touch) @@ -477,7 +487,9 @@ def first_or_new(self, _attributes=None, **attributes): return instance - def first_or_create(self, _attributes=None, _joining=None, _touch=True, **attributes): + def first_or_create( + self, _attributes=None, _joining=None, _touch=True, **attributes + ): """ Get the first related model record matching the attributes or create it. 
@@ -517,7 +529,7 @@ def update_or_create(self, attributes, values=None, joining=None, touch=True): instance.fill(**values) - instance.save({'touch': False}) + instance.save({"touch": False}) return instance @@ -535,7 +547,7 @@ def create(self, _attributes=None, _joining=None, _touch=True, **attributes): instance = self._related.new_instance(attributes) - instance.save({'touch': False}) + instance.save({"touch": False}) self.attach(instance.get_key(), _joining, _touch) @@ -561,11 +573,7 @@ def sync(self, ids, detaching=True): """ Sync the intermediate tables with a list of IDs or collection of models """ - changes = { - 'attached': [], - 'detached': [], - 'updated': [] - } + changes = {"attached": [], "detached": [], "updated": []} if isinstance(ids, Collection): ids = ids.model_keys() @@ -579,11 +587,11 @@ def sync(self, ids, detaching=True): if detaching and len(detach) > 0: self.detach(detach) - changes['detached'] = detach + changes["detached"] = detach changes.update(self._attach_new(records, current, False)) - if len(changes['attached']) or len(changes['updated']): + if len(changes["attached"]) or len(changes["updated"]): self.touch_if_touching() return changes @@ -609,18 +617,17 @@ def _attach_new(self, records, current, touch=True): """ Attach all of the IDs that aren't in the current dict. 
""" - changes = { - 'attached': [], - 'updated': [] - } + changes = {"attached": [], "updated": []} for id, attributes in records.items(): if id not in current: self.attach(id, attributes, touch) - changes['attached'].append(id) - elif len(attributes) > 0 and self.update_existing_pivot(id, attributes, touch): - changes['updated'].append(id) + changes["attached"].append(id) + elif len(attributes) > 0 and self.update_existing_pivot( + id, attributes, touch + ): + changes["updated"].append(id) return changes @@ -661,8 +668,9 @@ def _create_attach_records(self, ids, attributes): """ records = [] - timed = (self._has_pivot_column(self.created_at()) - or self._has_pivot_column(self.updated_at())) + timed = self._has_pivot_column(self.created_at()) or self._has_pivot_column( + self.updated_at() + ) for key, value in enumerate(ids): records.append(self._attacher(key, value, attributes, timed)) @@ -765,7 +773,9 @@ def _touching_parent(self): return self.get_related().touches(self._guess_inverse_relation()) def _guess_inverse_relation(self): - return inflection.camelize(inflection.pluralize(self.get_parent().__class__.__name__)) + return inflection.camelize( + inflection.pluralize(self.get_parent().__class__.__name__) + ) def _new_pivot_query(self): """ @@ -841,10 +851,10 @@ def get_has_compare_key(self): return self.get_foreign_key() def get_foreign_key(self): - return '%s.%s' % (self._table, self._foreign_key) + return "%s.%s" % (self._table, self._foreign_key) def get_other_key(self): - return '%s.%s' % (self._table, self._other_key) + return "%s.%s" % (self._table, self._other_key) def get_table(self): return self._table @@ -859,7 +869,7 @@ def _new_instance(self, model): self._table, self._foreign_key, self._other_key, - self._relation_name + self._relation_name, ) relation.with_pivot(*self._pivot_columns) diff --git a/orator/orm/relations/has_many.py b/orator/orm/relations/has_many.py index 63e3b65b..dd9e19cd 100644 --- a/orator/orm/relations/has_many.py +++ 
b/orator/orm/relations/has_many.py @@ -5,7 +5,6 @@ class HasMany(HasOneOrMany): - def get_results(self): """ Get the results of the relationship. @@ -20,7 +19,9 @@ def init_relation(self, models, relation): :type relation: str """ for model in models: - model.set_relation(relation, Result(self._related.new_collection(), self, model)) + model.set_relation( + relation, Result(self._related.new_collection(), self, model) + ) return models @@ -35,9 +36,4 @@ def match(self, models, results, relation): return self.match_many(models, results, relation) def _new_instance(self, model): - return HasMany( - self.new_query(), - model, - self._foreign_key, - self._local_key - ) + return HasMany(self.new_query(), model, self._foreign_key, self._local_key) diff --git a/orator/orm/relations/has_many_through.py b/orator/orm/relations/has_many_through.py index a1f6eb81..c17db200 100644 --- a/orator/orm/relations/has_many_through.py +++ b/orator/orm/relations/has_many_through.py @@ -6,7 +6,6 @@ class HasManyThrough(Relation): - def __init__(self, query, far_parent, parent, first_key, second_key): """ :param query: A Builder instance @@ -38,7 +37,11 @@ def add_constraints(self): self._set_join() if self._constraints: - self._query.where('%s.%s' % (parent_table, self._first_key), '=', self._far_parent.get_key()) + self._query.where( + "%s.%s" % (parent_table, self._first_key), + "=", + self._far_parent.get_key(), + ) def get_relation_count_query(self, query, parent): """ @@ -53,11 +56,11 @@ def get_relation_count_query(self, query, parent): self._set_join(query) - query.select(QueryExpression('COUNT(*)')) + query.select(QueryExpression("COUNT(*)")) - key = self.wrap('%s.%s' % (parent_table, self._first_key)) + key = self.wrap("%s.%s" % (parent_table, self._first_key)) - return query.where(self.get_has_compare_key(), '=', QueryExpression(key)) + return query.where(self.get_has_compare_key(), "=", QueryExpression(key)) def _set_join(self, query=None): """ @@ -66,9 +69,14 @@ def 
_set_join(self, query=None): if not query: query = self._query - foreign_key = '%s.%s' % (self._related.get_table(), self._second_key) + foreign_key = "%s.%s" % (self._related.get_table(), self._second_key) - query.join(self._parent.get_table(), self.get_qualified_parent_key_name(), '=', foreign_key) + query.join( + self._parent.get_table(), + self.get_qualified_parent_key_name(), + "=", + foreign_key, + ) def add_eager_constraints(self, models): """ @@ -78,7 +86,7 @@ def add_eager_constraints(self, models): """ table = self._parent.get_table() - self._query.where_in('%s.%s' % (table, self._first_key), self.get_keys(models)) + self._query.where_in("%s.%s" % (table, self._first_key), self.get_keys(models)) def init_relation(self, models, relation): """ @@ -88,7 +96,9 @@ def init_relation(self, models, relation): :type relation: str """ for model in models: - model.set_relation(relation, Result(self._related.new_collection(), self, model)) + model.set_relation( + relation, Result(self._related.new_collection(), self, model) + ) return models @@ -106,7 +116,9 @@ def match(self, models, results, relation): key = model.get_key() if key in dictionary: - value = Result(self._related.new_collection(dictionary[key]), self, model) + value = Result( + self._related.new_collection(dictionary[key]), self, model + ) else: value = Result(self._related.new_collection(), self, model) @@ -151,7 +163,7 @@ def get(self, columns=None): :rtype: orator.Collection """ if columns is None: - columns = ['*'] + columns = ["*"] select = self._get_select_columns(columns) @@ -171,19 +183,15 @@ def _get_select_columns(self, columns=None): :rtype: list """ - if columns == ['*'] or columns is None: - columns = ['%s.*' % self._related.get_table()] + if columns == ["*"] or columns is None: + columns = ["%s.*" % self._related.get_table()] - return columns + ['%s.%s' % (self._parent.get_table(), self._first_key)] + return columns + ["%s.%s" % (self._parent.get_table(), self._first_key)] def 
get_has_compare_key(self): return self._far_parent.get_qualified_key_name() def _new_instance(self, model): return HasManyThrough( - self.new_query(), - model, - self._parent, - self._first_key, - self._second_key + self.new_query(), model, self._parent, self._first_key, self._second_key ) diff --git a/orator/orm/relations/has_one.py b/orator/orm/relations/has_one.py index 1830040f..a5bf69f7 100644 --- a/orator/orm/relations/has_one.py +++ b/orator/orm/relations/has_one.py @@ -5,7 +5,6 @@ class HasOne(HasOneOrMany): - def get_results(self): """ Get the results of the relationship. @@ -35,9 +34,4 @@ def match(self, models, results, relation): return self.match_one(models, results, relation) def _new_instance(self, model): - return HasOne( - self.new_query(), - model, - self._foreign_key, - self._local_key - ) + return HasOne(self.new_query(), model, self._foreign_key, self._local_key) diff --git a/orator/orm/relations/has_one_or_many.py b/orator/orm/relations/has_one_or_many.py index 716fe7fe..7e356e37 100644 --- a/orator/orm/relations/has_one_or_many.py +++ b/orator/orm/relations/has_one_or_many.py @@ -6,7 +6,6 @@ class HasOneOrMany(Relation): - def __init__(self, query, parent, foreign_key, local_key): """ :type query: orator.orm.Builder @@ -30,7 +29,7 @@ def add_constraints(self): Set the base constraints of the relation query """ if self._constraints: - self._query.where(self._foreign_key, '=', self.get_parent_key()) + self._query.where(self._foreign_key, "=", self.get_parent_key()) def add_eager_constraints(self, models): """ @@ -38,7 +37,9 @@ def add_eager_constraints(self, models): :type models: list """ - return self._query.where_in(self._foreign_key, self.get_keys(models, self._local_key)) + return self._query.where_in( + self._foreign_key, self.get_keys(models, self._local_key) + ) def match_one(self, models, results, relation): """ @@ -55,7 +56,7 @@ def match_one(self, models, results, relation): :rtype: list """ - return self._match_one_or_many(models, 
results, relation, 'one') + return self._match_one_or_many(models, results, relation, "one") def match_many(self, models, results, relation): """ @@ -72,7 +73,7 @@ def match_many(self, models, results, relation): :rtype: list """ - return self._match_one_or_many(models, results, relation, 'many') + return self._match_one_or_many(models, results, relation, "many") def _match_one_or_many(self, models, results, relation, type_): """ @@ -98,9 +99,11 @@ def _match_one_or_many(self, models, results, relation, type_): key = model.get_attribute(self._local_key) if key in dictionary: - value = Result(self._get_relation_value(dictionary, key, type_), self, model) + value = Result( + self._get_relation_value(dictionary, key, type_), self, model + ) else: - if type_ == 'one': + if type_ == "one": value = Result(None, self, model) else: value = Result(self._related.new_collection(), self, model) @@ -119,7 +122,7 @@ def _get_relation_value(self, dictionary, key, type): """ value = dictionary[key] - if type == 'one': + if type == "one": return value[0] return self._related.new_collection(value) @@ -171,7 +174,7 @@ def save_many(self, models): :rtype: list """ - return map(self.save, models) + return list(map(self.save, models)) def find_or_new(self, id, columns=None): """ @@ -186,7 +189,7 @@ def find_or_new(self, id, columns=None): :rtype: Collection or Model """ if columns is None: - columns = ['*'] + columns = ["*"] instance = self._query.find(id, columns) @@ -315,7 +318,7 @@ def get_foreign_key(self): return self._foreign_key def get_plain_foreign_key(self): - segments = self.get_foreign_key().split('.') + segments = self.get_foreign_key().split(".") return segments[-1] @@ -323,4 +326,4 @@ def get_parent_key(self): return self._parent.get_attribute(self._local_key) def get_qualified_parent_key_name(self): - return '%s.%s' % (self._parent.get_table(), self._local_key) + return "%s.%s" % (self._parent.get_table(), self._local_key) diff --git a/orator/orm/relations/morph_many.py 
b/orator/orm/relations/morph_many.py index bd5fdc70..1450ea6e 100644 --- a/orator/orm/relations/morph_many.py +++ b/orator/orm/relations/morph_many.py @@ -5,7 +5,6 @@ class MorphMany(MorphOneOrMany): - def get_results(self): """ Get the results of the relationship. @@ -20,7 +19,9 @@ def init_relation(self, models, relation): :type relation: str """ for model in models: - model.set_relation(relation, Result(self._related.new_collection(), self, model)) + model.set_relation( + relation, Result(self._related.new_collection(), self, model) + ) return models diff --git a/orator/orm/relations/morph_one.py b/orator/orm/relations/morph_one.py index 96c9e73d..fac1814b 100644 --- a/orator/orm/relations/morph_one.py +++ b/orator/orm/relations/morph_one.py @@ -5,7 +5,6 @@ class MorphOne(MorphOneOrMany): - def get_results(self): """ Get the results of the relationship. diff --git a/orator/orm/relations/morph_one_or_many.py b/orator/orm/relations/morph_one_or_many.py index acaee756..dcfba3f8 100644 --- a/orator/orm/relations/morph_one_or_many.py +++ b/orator/orm/relations/morph_one_or_many.py @@ -4,7 +4,6 @@ class MorphOneOrMany(HasOneOrMany): - def __init__(self, query, parent, morph_type, foreign_key, local_key): """ :type query: orator.orm.Builder @@ -84,7 +83,7 @@ def find_or_new(self, id, columns=None): :rtype: Collection or Model """ if columns is None: - columns = ['*'] + columns = ["*"] instance = self._query.find(id, columns) @@ -185,7 +184,7 @@ def get_morph_type(self): return self._morph_type def get_plain_morph_type(self): - return self._morph_type.split('.')[-1] + return self._morph_type.split(".")[-1] def get_morph_name(self): return self._morph_name @@ -196,5 +195,5 @@ def _new_instance(self, parent): parent, self._morph_type, self._foreign_key, - self._local_key + self._local_key, ) diff --git a/orator/orm/relations/morph_to.py b/orator/orm/relations/morph_to.py index 46f4b430..323f1e74 100644 --- a/orator/orm/relations/morph_to.py +++ 
b/orator/orm/relations/morph_to.py @@ -7,7 +7,6 @@ class MorphTo(BelongsTo): - def __init__(self, query, parent, foreign_key, other_key, type, relation): """ :type query: orator.orm.Builder @@ -87,7 +86,9 @@ def associate(self, model): self._parent.set_attribute(self._foreign_key, model.get_key()) self._parent.set_attribute(self._morph_type, model.get_morph_name()) - return self._parent.set_relation(self._relation, Result(model, self, self._parent)) + return self._parent.set_relation( + self._relation, Result(model, self, self._parent) + ) def get_eager(self): """ @@ -114,11 +115,7 @@ def _match_to_morph_parents(self, type, results): if result.get_key() in self._dictionary.get(type, []): for model in self._dictionary[type][result.get_key()]: model.set_relation( - self._relation, - Result( - result, self, model, - related=result - ) + self._relation, Result(result, self, model, related=result) ) def _get_results_by_type(self, type): @@ -151,9 +148,11 @@ def _gather_keys_by_type(self, type): """ foreign = self._foreign_key - keys = BaseCollection.make(list(self._dictionary[type].values()))\ - .map(lambda models: getattr(models[0], foreign))\ + keys = ( + BaseCollection.make(list(self._dictionary[type].values())) + .map(lambda models: getattr(models[0], foreign)) .unique() + ) return keys @@ -193,5 +192,5 @@ def _new_instance(self, model, related=None): self._foreign_key, self._other_key if not related else related.get_key_name(), self._morph_type, - self._relation + self._relation, ) diff --git a/orator/orm/relations/morph_to_many.py b/orator/orm/relations/morph_to_many.py index 599644f0..85f8fd8f 100644 --- a/orator/orm/relations/morph_to_many.py +++ b/orator/orm/relations/morph_to_many.py @@ -4,9 +4,17 @@ class MorphToMany(BelongsToMany): - - def __init__(self, query, parent, name, table, - foreign_key, other_key, relation_name=None, inverse=False): + def __init__( + self, + query, + parent, + name, + table, + foreign_key, + other_key, + relation_name=None, + 
inverse=False, + ): """ :param query: A Builder instance :type query: elquent.orm.Builder @@ -30,12 +38,13 @@ def __init__(self, query, parent, name, table, """ self._name = name self._inverse = inverse - self._morph_type = name + '_type' - self._morph_name = query.get_model().get_morph_name() if inverse else parent.get_morph_name() + self._morph_type = name + "_type" + self._morph_name = ( + query.get_model().get_morph_name() if inverse else parent.get_morph_name() + ) super(MorphToMany, self).__init__( - query, parent, table, - foreign_key, other_key, relation_name + query, parent, table, foreign_key, other_key, relation_name ) def _set_where(self): @@ -47,7 +56,7 @@ def _set_where(self): """ super(MorphToMany, self)._set_where() - self._query.where('%s.%s' % (self._table, self._morph_type), self._morph_name) + self._query.where("%s.%s" % (self._table, self._morph_type), self._morph_name) def get_relation_count_query(self, query, parent): """ @@ -60,7 +69,7 @@ def get_relation_count_query(self, query, parent): """ query = super(MorphToMany, self).get_relation_count_query(query, parent) - return query.where('%s.%s' % (self._table, self._morph_type), self._morph_name) + return query.where("%s.%s" % (self._table, self._morph_type), self._morph_name) def add_eager_constraints(self, models): """ @@ -70,7 +79,7 @@ def add_eager_constraints(self, models): """ super(MorphToMany, self).add_eager_constraints(models) - self._query.where('%s.%s' % (self._table, self._morph_type), self._morph_name) + self._query.where("%s.%s" % (self._table, self._morph_type), self._morph_name) def _create_attach_record(self, id, timed): """ @@ -100,9 +109,9 @@ def new_pivot(self, attributes=None, exists=False): pivot = MorphPivot(self._parent, attributes, self._table, exists) - pivot.set_pivot_keys(self._foreign_key, self._other_key)\ - .set_morph_type(self._morph_type)\ - .set_morph_name(self._morph_name) + pivot.set_pivot_keys(self._foreign_key, self._other_key).set_morph_type( + 
self._morph_type + ).set_morph_name(self._morph_name) return pivot @@ -121,5 +130,5 @@ def _new_instance(self, model): self._foreign_key, self._other_key, self._relation_name, - self._inverse + self._inverse, ) diff --git a/orator/orm/relations/relation.py b/orator/orm/relations/relation.py index 4cfc5f96..7a311b73 100644 --- a/orator/orm/relations/relation.py +++ b/orator/orm/relations/relation.py @@ -93,7 +93,8 @@ def raw_update(self, attributes=None): if attributes is None: attributes = {} - return self._query.update(attributes) + if self._query is not None: + return self._query.update(attributes) def get_relation_count_query(self, query, parent): """ @@ -104,11 +105,11 @@ def get_relation_count_query(self, query, parent): :rtype: Builder """ - query.select(QueryExpression('COUNT(*)')) + query.select(QueryExpression("COUNT(*)")) key = self.wrap(self.get_qualified_parent_key_name()) - return query.where(self.get_has_compare_key(), '=', QueryExpression(key)) + return query.where(self.get_has_compare_key(), "=", QueryExpression(key)) @classmethod @contextmanager @@ -141,7 +142,14 @@ def get_keys(self, models, key=None): :rtype: list """ - return list(set(map(lambda value: value.get_attribute(key) if key else value.get_key(), models))) + return list( + set( + map( + lambda value: value.get_attribute(key) if key else value.get_key(), + models, + ) + ) + ) def get_query(self): return self._query diff --git a/orator/orm/relations/wrapper.py b/orator/orm/relations/wrapper.py index ae237446..2949c00e 100644 --- a/orator/orm/relations/wrapper.py +++ b/orator/orm/relations/wrapper.py @@ -41,7 +41,6 @@ def __repr__(self): class BelongsToManyWrapper(Wrapper): - def with_timestamps(self): self._relation.with_timestamps() diff --git a/orator/orm/scopes/scope.py b/orator/orm/scopes/scope.py index 9293285a..fae90f84 100644 --- a/orator/orm/scopes/scope.py +++ b/orator/orm/scopes/scope.py @@ -2,7 +2,6 @@ class Scope(object): - def apply(self, builder, model): """ Apply the scope 
to a given query builder. diff --git a/orator/orm/scopes/soft_deleting.py b/orator/orm/scopes/soft_deleting.py index edfca0ca..0e8efb2e 100644 --- a/orator/orm/scopes/soft_deleting.py +++ b/orator/orm/scopes/soft_deleting.py @@ -5,7 +5,7 @@ class SoftDeletingScope(Scope): - _extensions = ['force_delete', 'restore', 'with_trashed', 'only_trashed'] + _extensions = ["force_delete", "restore", "with_trashed", "only_trashed"] def apply(self, builder, model): """ @@ -29,7 +29,7 @@ def extend(self, builder): :type builder: orator.orm.builder.Builder """ for extension in self._extensions: - getattr(self, '_add_%s' % extension)(builder) + getattr(self, "_add_%s" % extension)(builder) builder.on_delete(self._on_delete) @@ -42,9 +42,7 @@ def _on_delete(self, builder): """ column = self._get_deleted_at_column(builder) - return builder.update({ - column: builder.get_model().fresh_timestamp() - }) + return builder.update({column: builder.get_model().fresh_timestamp()}) def _get_deleted_at_column(self, builder): """ @@ -67,7 +65,7 @@ def _add_force_delete(self, builder): :param builder: The query builder :type builder: orator.orm.builder.Builder """ - builder.macro('force_delete', self._force_delete) + builder.macro("force_delete", self._force_delete) def _force_delete(self, builder): """ @@ -85,7 +83,7 @@ def _add_restore(self, builder): :param builder: The query builder :type builder: orator.orm.builder.Builder """ - builder.macro('restore', self._restore) + builder.macro("restore", self._restore) def _restore(self, builder): """ @@ -96,9 +94,7 @@ def _restore(self, builder): """ builder.with_trashed() - return builder.update({ - builder.get_model().get_deleted_at_column(): None - }) + return builder.update({builder.get_model().get_deleted_at_column(): None}) def _add_with_trashed(self, builder): """ @@ -107,7 +103,7 @@ def _add_with_trashed(self, builder): :param builder: The query builder :type builder: orator.orm.builder.Builder """ - builder.macro('with_trashed', 
self._with_trashed) + builder.macro("with_trashed", self._with_trashed) def _with_trashed(self, builder): """ @@ -127,7 +123,7 @@ def _add_only_trashed(self, builder): :param builder: The query builder :type builder: orator.orm.builder.Builder """ - builder.macro('only_trashed', self._only_trashed) + builder.macro("only_trashed", self._only_trashed) def _only_trashed(self, builder): """ diff --git a/orator/orm/utils.py b/orator/orm/utils.py index 449b4146..2eae9ac2 100644 --- a/orator/orm/utils.py +++ b/orator/orm/utils.py @@ -6,14 +6,19 @@ from .builder import Builder from ..query import QueryBuilder from .relations import ( - HasOne, HasMany, HasManyThrough, - BelongsTo, BelongsToMany, - MorphOne, MorphMany, MorphTo, MorphToMany + HasOne, + HasMany, + HasManyThrough, + BelongsTo, + BelongsToMany, + MorphOne, + MorphMany, + MorphTo, + MorphToMany, ) class accessor(object): - def __init__(self, accessor_, attribute=None): self.accessor = accessor_ self.mutator_ = None @@ -48,7 +53,6 @@ def mutator(self, f): class mutator(object): - def __init__(self, mutator_, attribute=None): self.mutator = mutator_ self.accessor_ = None @@ -73,7 +77,6 @@ def accessor(self, f): class column(object): - def __init__(self, property_, attribute=None): self.property = property_ self.mutator_ = None @@ -231,7 +234,7 @@ def _get(self, instance): self._foreign_key, self._local_key, self._relation, - _wrapped=False + _wrapped=False, ) @@ -242,9 +245,11 @@ class morph_one(relation): relation_class = MorphOne - def __init__(self, name, type_column=None, id_column=None, local_key=None, relation=None): + def __init__( + self, name, type_column=None, id_column=None, local_key=None, relation=None + ): if isinstance(name, (types.FunctionType, types.MethodType)): - raise RuntimeError('morph_one relation requires a name') + raise RuntimeError("morph_one relation requires a name") self._name = name self._type_column = type_column @@ -255,10 +260,13 @@ def __init__(self, name, type_column=None, 
id_column=None, local_key=None, relat def _get(self, instance): return instance.morph_one( - self._related, self._name, - self._type_column, self._id_column, - self._local_key, self._relation, - _wrapped=False + self._related, + self._name, + self._type_column, + self._id_column, + self._local_key, + self._relation, + _wrapped=False, ) @@ -287,7 +295,7 @@ def _get(self, instance): self._foreign_key, self._other_key, self._relation, - _wrapped=False + _wrapped=False, ) def _set(self, relation): @@ -318,9 +326,7 @@ def __init__(self, name=None, type_column=None, id_column=None): def _get(self, instance): return instance.morph_to( - self._relation, - self._type_column, self._id_column, - _wrapped=False + self._relation, self._type_column, self._id_column, _wrapped=False ) @@ -349,7 +355,7 @@ def _get(self, instance): self._foreign_key, self._local_key, self._relation, - _wrapped=False + _wrapped=False, ) @@ -362,7 +368,9 @@ class has_many_through(relation): def __init__(self, through, first_key=None, second_key=None, relation=None): if isinstance(through, (types.FunctionType, types.MethodType)): - raise RuntimeError('has_many_through relation requires the through parameter') + raise RuntimeError( + "has_many_through relation requires the through parameter" + ) self._through = through self._first_key = first_key @@ -377,7 +385,7 @@ def _get(self, instance): self._first_key, self._second_key, self._relation, - _wrapped=False + _wrapped=False, ) @@ -388,9 +396,11 @@ class morph_many(relation): relation_class = MorphMany - def __init__(self, name, type_column=None, id_column=None, local_key=None, relation=None): + def __init__( + self, name, type_column=None, id_column=None, local_key=None, relation=None + ): if isinstance(name, (types.FunctionType, types.MethodType)): - raise RuntimeError('morph_many relation requires a name') + raise RuntimeError("morph_many relation requires a name") self._name = name self._type_column = type_column @@ -401,10 +411,13 @@ def 
__init__(self, name, type_column=None, id_column=None, local_key=None, relat def _get(self, instance): return instance.morph_many( - self._related, self._name, - self._type_column, self._id_column, - self._local_key, self._relation, - _wrapped=False + self._related, + self._name, + self._type_column, + self._id_column, + self._local_key, + self._relation, + _wrapped=False, ) @@ -415,8 +428,15 @@ class belongs_to_many(relation): relation_class = BelongsToMany - def __init__(self, table=None, foreign_key=None, other_key=None, - relation=None, with_timestamps=False, with_pivot=None): + def __init__( + self, + table=None, + foreign_key=None, + other_key=None, + relation=None, + with_timestamps=False, + with_pivot=None, + ): if isinstance(table, (types.FunctionType, types.MethodType)): func = table table = None @@ -439,7 +459,7 @@ def _get(self, instance): self._foreign_key, self._other_key, self._relation, - _wrapped=False + _wrapped=False, ) if self._timestamps: @@ -458,9 +478,11 @@ class morph_to_many(relation): relation_class = MorphToMany - def __init__(self, name, table=None, foreign_key=None, other_key=None, relation=None): + def __init__( + self, name, table=None, foreign_key=None, other_key=None, relation=None + ): if isinstance(name, (types.FunctionType, types.MethodType)): - raise RuntimeError('morph_to_many relation required a name') + raise RuntimeError("morph_to_many relation required a name") self._name = name self._table = table @@ -477,7 +499,7 @@ def _get(self, instance): self._foreign_key, self._other_key, relation=self._relation, - _wrapped=False + _wrapped=False, ) @@ -488,9 +510,11 @@ class morphed_by_many(relation): relation_class = MorphToMany - def __init__(self, name, table=None, foreign_key=None, other_key=None, relation=None): + def __init__( + self, name, table=None, foreign_key=None, other_key=None, relation=None + ): if isinstance(foreign_key, (types.FunctionType, types.MethodType)): - raise RuntimeError('morphed_by_many relation requires 
a name') + raise RuntimeError("morphed_by_many relation requires a name") self._name = name self._table = table @@ -507,5 +531,5 @@ def _get(self, instance): self._foreign_key, self._other_key, self._relation, - _wrapped=False + _wrapped=False, ) diff --git a/orator/pagination/length_aware_paginator.py b/orator/pagination/length_aware_paginator.py index 128154b4..c9a910ed 100644 --- a/orator/pagination/length_aware_paginator.py +++ b/orator/pagination/length_aware_paginator.py @@ -9,7 +9,6 @@ class LengthAwarePaginator(BasePaginator): - def __init__(self, items, total, per_page, current_page=None, options=None): """ Constructor diff --git a/orator/pagination/paginator.py b/orator/pagination/paginator.py index f3041c98..3b0f04f7 100644 --- a/orator/pagination/paginator.py +++ b/orator/pagination/paginator.py @@ -6,7 +6,6 @@ class Paginator(BasePaginator): - def __init__(self, items, per_page, current_page=None, options=None): """ Constructor @@ -59,7 +58,7 @@ def _check_for_more_pages(self): """ self._has_more = len(self._items) > self.per_page - self._items = self._items[0:self.per_page] + self._items = self._items[0 : self.per_page] def has_more_pages(self): """ diff --git a/orator/query/builder.py b/orator/query/builder.py index 60f59c10..6e8aaa6c 100644 --- a/orator/query/builder.py +++ b/orator/query/builder.py @@ -2,8 +2,11 @@ import re import copy +import datetime + from itertools import chain from collections import OrderedDict + from .expression import QueryExpression from .join_clause import JoinClause from ..pagination import Paginator, LengthAwarePaginator @@ -15,12 +18,32 @@ class QueryBuilder(object): _operators = [ - '=', '<', '>', '<=', '>=', '<>', '!=', - 'like', 'like binary', 'not like', 'between', 'ilike', - '&', '|', '^', '<<', '>>', - 'rlike', 'regexp', 'not regexp', - '~', '~*', '!~', '!~*', 'similar to', - 'not similar to', + "=", + "<", + ">", + "<=", + ">=", + "<>", + "!=", + "like", + "like binary", + "not like", + "between", + "ilike", + 
"&", + "|", + "^", + "<<", + ">>", + "rlike", + "regexp", + "not regexp", + "~", + "~*", + "!~", + "!~*", + "similar to", + "not similar to", ] def __init__(self, connection, grammar, processor): @@ -40,13 +63,13 @@ def __init__(self, connection, grammar, processor): self._processor = processor self._connection = connection self._bindings = OrderedDict() - for type in ['select', 'join', 'where', 'having', 'order']: + for type in ["select", "join", "where", "having", "order"]: self._bindings[type] = [] self.aggregate_ = None self.columns = [] self.distinct_ = False - self.from__ = '' + self.from__ = "" self.joins = [] self.wheres = [] self.groups = [] @@ -75,7 +98,7 @@ def select(self, *columns): :rtype: QueryBuilder """ if not columns: - columns = ['*'] + columns = ["*"] self.columns = list(columns) @@ -97,7 +120,7 @@ def select_raw(self, expression, bindings=None): self.add_select(QueryExpression(expression)) if bindings: - self.add_binding(bindings, 'select') + self.add_binding(bindings, "select") return self @@ -121,9 +144,11 @@ def select_sub(self, query, as_): elif isinstance(query, basestring): bindings = [] else: - raise ArgumentError('Invalid subselect') + raise ArgumentError("Invalid subselect") - return self.select_raw('(%s) AS %s' % (query, self._grammar.wrap(as_)), bindings) + return self.select_raw( + "(%s) AS %s" % (query, self._grammar.wrap(as_)), bindings + ) def add_select(self, *column): """ @@ -167,8 +192,7 @@ def from_(self, table): return self - def join(self, table, one=None, - operator=None, two=None, type='inner', where=False): + def join(self, table, one=None, operator=None, two=None, type="inner", where=False): """ Add a join clause to the query @@ -201,13 +225,11 @@ def join(self, table, one=None, join = JoinClause(table, type) - self.joins.append(join.on( - one, operator, two, 'and', where - )) + self.joins.append(join.on(one, operator, two, "and", where)) return self - def join_where(self, table, one, operator, two, type='inner'): + def 
join_where(self, table, one, operator, two, type="inner"): """ Add a "join where" clause to the query @@ -251,9 +273,9 @@ def left_join(self, table, one=None, operator=None, two=None): :rtype: QueryBuilder """ if isinstance(table, JoinClause): - table.type = 'left' + table.type = "left" - return self.join(table, one, operator, two, 'left') + return self.join(table, one, operator, two, "left") def left_join_where(self, table, one, operator, two): """ @@ -274,7 +296,7 @@ def left_join_where(self, table, one, operator, two): :return: The current QueryBuilder instance :rtype: QueryBuilder """ - return self.join_where(table, one, operator, two, 'left') + return self.join_where(table, one, operator, two, "left") def right_join(self, table, one=None, operator=None, two=None): """ @@ -296,9 +318,9 @@ def right_join(self, table, one=None, operator=None, two=None): :rtype: QueryBuilder """ if isinstance(table, JoinClause): - table.type = 'right' + table.type = "right" - return self.join(table, one, operator, two, 'right') + return self.join(table, one, operator, two, "right") def right_join_where(self, table, one, operator, two): """ @@ -319,9 +341,9 @@ def right_join_where(self, table, one, operator, two): :return: The current QueryBuilder instance :rtype: QueryBuilder """ - return self.join_where(table, one, operator, two, 'right') + return self.join_where(table, one, operator, two, "right") - def where(self, column, operator=Null(), value=None, boolean='and'): + def where(self, column, operator=Null(), value=None, boolean="and"): """ Add a where clause to the query @@ -346,7 +368,7 @@ def where(self, column, operator=Null(), value=None, boolean='and'): if isinstance(column, dict): nested = self.new_query() for key, value in column.items(): - nested.where(key, '=', value) + nested.where(key, "=", value) return self.where_nested(nested, boolean) @@ -359,89 +381,84 @@ def where(self, column, operator=Null(), value=None, boolean='and'): if isinstance(condition, list) and 
len(condition) == 3: nested.where(condition[0], condition[1], condition[2]) else: - raise ArgumentError('Invalid conditions in where() clause') + raise ArgumentError("Invalid conditions in where() clause") return self.where_nested(nested, boolean) if value is None: if not isinstance(operator, Null): value = operator - operator = '=' + operator = "=" else: - raise ArgumentError('Value must be provided') + raise ArgumentError("Value must be provided") if operator not in self._operators: value = operator - operator = '=' + operator = "=" if isinstance(value, QueryBuilder): return self._where_sub(column, operator, value, boolean) if value is None: - return self.where_null(column, boolean, operator != '=') - - type = 'basic' - - self.wheres.append({ - 'type': type, - 'column': column, - 'operator': operator, - 'value': value, - 'boolean': boolean - }) + return self.where_null(column, boolean, operator != "=") + + type = "basic" + + self.wheres.append( + { + "type": type, + "column": column, + "operator": operator, + "value": value, + "boolean": boolean, + } + ) if not isinstance(value, QueryExpression): - self.add_binding(value, 'where') + self.add_binding(value, "where") return self def or_where(self, column, operator=None, value=None): - return self.where(column, operator, value, 'or') + return self.where(column, operator, value, "or") def _invalid_operator_and_value(self, operator, value): is_operator = operator in self._operators - return is_operator and operator != '=' and value is None + return is_operator and operator != "=" and value is None - def where_raw(self, sql, bindings=None, boolean='and'): - type = 'raw' + def where_raw(self, sql, bindings=None, boolean="and"): + type = "raw" - self.wheres.append({ - 'type': type, - 'sql': sql, - 'boolean': boolean - }) + self.wheres.append({"type": type, "sql": sql, "boolean": boolean}) - self.add_binding(bindings, 'where') + self.add_binding(bindings, "where") return self def or_where_raw(self, sql, bindings=None): - 
return self.where_raw(sql, bindings, 'or') + return self.where_raw(sql, bindings, "or") - def where_between(self, column, values, boolean='and', negate=False): - type = 'between' + def where_between(self, column, values, boolean="and", negate=False): + type = "between" - self.wheres.append({ - 'column': column, - 'type': type, - 'boolean': boolean, - 'not': negate - }) + self.wheres.append( + {"column": column, "type": type, "boolean": boolean, "not": negate} + ) - self.add_binding(values, 'where') + self.add_binding(values, "where") return self def or_where_between(self, column, values): - return self.where_between(column, values, 'or') + return self.where_between(column, values, "or") - def where_not_between(self, column, values, boolean='and'): + def where_not_between(self, column, values, boolean="and"): return self.where_between(column, values, boolean, True) def or_where_not_between(self, column, values): - return self.where_not_between(column, values, 'or') + return self.where_not_between(column, values, "or") - def where_nested(self, query, boolean='and'): + def where_nested(self, query, boolean="and"): query.from_(self.from__) return self.add_nested_where_query(query, boolean) @@ -456,36 +473,34 @@ def for_nested_where(self): return query.from_(self.from__) - def add_nested_where_query(self, query, boolean='and'): + def add_nested_where_query(self, query, boolean="and"): if len(query.wheres): - type = 'nested' + type = "nested" - self.wheres.append({ - 'type': type, - 'query': query, - 'boolean': boolean - }) + self.wheres.append({"type": type, "query": query, "boolean": boolean}) self.merge_bindings(query) return self def _where_sub(self, column, operator, query, boolean): - type = 'sub' - - self.wheres.append({ - 'type': type, - 'column': column, - 'operator': operator, - 'query': query, - 'boolean': boolean - }) + type = "sub" + + self.wheres.append( + { + "type": type, + "column": column, + "operator": operator, + "query": query, + "boolean": boolean, 
+ } + ) self.merge_bindings(query) return self - def where_exists(self, query, boolean='and', negate=False): + def where_exists(self, query, boolean="and", negate=False): """ Add an exists clause to the query. @@ -499,15 +514,11 @@ def where_exists(self, query, boolean='and', negate=False): :rtype: QueryBuilder """ if negate: - type = 'not_exists' + type = "not_exists" else: - type = 'exists' + type = "exists" - self.wheres.append({ - 'type': type, - 'query': query, - 'boolean': boolean - }) + self.wheres.append({"type": type, "query": query, "boolean": boolean}) self.merge_bindings(query) @@ -524,9 +535,9 @@ def or_where_exists(self, query, negate=False): :rtype: QueryBuilder """ - return self.where_exists(query, 'or', negate) + return self.where_exists(query, "or", negate) - def where_not_exists(self, query, boolean='and'): + def where_not_exists(self, query, boolean="and"): """ Add a where not exists clause to the query. @@ -550,11 +561,11 @@ def or_where_not_exists(self, query): """ return self.or_where_exists(query, True) - def where_in(self, column, values, boolean='and', negate=False): + def where_in(self, column, values, boolean="and", negate=False): if negate: - type = 'not_in' + type = "not_in" else: - type = 'in' + type = "in" if isinstance(values, QueryBuilder): return self._where_in_sub(column, values, boolean, negate) @@ -562,25 +573,22 @@ def where_in(self, column, values, boolean='and', negate=False): if isinstance(values, Collection): values = values.all() - self.wheres.append({ - 'type': type, - 'column': column, - 'values': values, - 'boolean': boolean - }) + self.wheres.append( + {"type": type, "column": column, "values": values, "boolean": boolean} + ) - self.add_binding(values, 'where') + self.add_binding(values, "where") return self def or_where_in(self, column, values): - return self.where_in(column, values, 'or') + return self.where_in(column, values, "or") - def where_not_in(self, column, values, boolean='and'): + def where_not_in(self, 
column, values, boolean="and"): return self.where_in(column, values, boolean, True) def or_where_not_in(self, column, values): - return self.where_not_in(column, values, 'or') + return self.where_not_in(column, values, "or") def _where_in_sub(self, column, query, boolean, negate=False): """ @@ -602,79 +610,74 @@ def _where_in_sub(self, column, query, boolean, negate=False): :rtype: QueryBuilder """ if negate: - type = 'not_in_sub' + type = "not_in_sub" else: - type = 'in_sub' + type = "in_sub" - self.wheres.append({ - 'type': type, - 'column': column, - 'query': query, - 'boolean': boolean - }) + self.wheres.append( + {"type": type, "column": column, "query": query, "boolean": boolean} + ) self.merge_bindings(query) return self - def where_null(self, column, boolean='and', negate=False): + def where_null(self, column, boolean="and", negate=False): if negate: - type = 'not_null' + type = "not_null" else: - type = 'null' + type = "null" - self.wheres.append({ - 'type': type, - 'column': column, - 'boolean': boolean - }) + self.wheres.append({"type": type, "column": column, "boolean": boolean}) return self def or_where_null(self, column): - return self.where_null(column, 'or') + return self.where_null(column, "or") - def where_not_null(self, column, boolean='and'): + def where_not_null(self, column, boolean="and"): return self.where_null(column, boolean, True) def or_where_not_null(self, column): - return self.where_not_null(column, 'or') + return self.where_not_null(column, "or") - def where_date(self, column, operator, value, boolean='and'): - return self._add_date_based_where('date', column, operator, value, boolean) + def where_date(self, column, operator, value, boolean="and"): + return self._add_date_based_where("date", column, operator, value, boolean) - def where_day(self, column, operator, value, boolean='and'): - return self._add_date_based_where('day', column, operator, value, boolean) + def where_day(self, column, operator, value, boolean="and"): + return 
self._add_date_based_where("day", column, operator, value, boolean) - def where_month(self, column, operator, value, boolean='and'): - return self._add_date_based_where('month', column, operator, value, boolean) + def where_month(self, column, operator, value, boolean="and"): + return self._add_date_based_where("month", column, operator, value, boolean) - def where_year(self, column, operator, value, boolean='and'): - return self._add_date_based_where('year', column, operator, value, boolean) + def where_year(self, column, operator, value, boolean="and"): + return self._add_date_based_where("year", column, operator, value, boolean) - def _add_date_based_where(self, type, column, operator, value, boolean='and'): - self.wheres.append({ - 'type': type, - 'column': column, - 'boolean': boolean, - 'operator': operator, - 'value': value - }) + def _add_date_based_where(self, type, column, operator, value, boolean="and"): + self.wheres.append( + { + "type": type, + "column": column, + "boolean": boolean, + "operator": operator, + "value": value, + } + ) - self.add_binding(value, 'where') + self.add_binding(value, "where") def dynamic_where(self, method): finder = method[6:] def dynamic_where(*parameters): - segments = re.split('_(and|or)_(?=[a-z])', finder, 0, re.I) + segments = re.split("_(and|or)_(?=[a-z])", finder, 0, re.I) - connector = 'and' + connector = "and" index = 0 for segment in segments: - if segment.lower() != 'and' and segment.lower() != 'or': + if segment.lower() != "and" and segment.lower() != "or": self._add_dynamic(segment, connector, parameters, index) index += 1 @@ -686,7 +689,7 @@ def dynamic_where(*parameters): return dynamic_where def _add_dynamic(self, segment, connector, parameters, index): - self.where(segment, '=', parameters[index], connector) + self.where(segment, "=", parameters[index], connector) def group_by(self, *columns): """ @@ -703,7 +706,7 @@ def group_by(self, *columns): return self - def having(self, column, operator=None, 
value=None, boolean='and'): + def having(self, column, operator=None, value=None, boolean="and"): """ Add a "having" clause to the query @@ -722,18 +725,20 @@ def having(self, column, operator=None, value=None, boolean='and'): :return: The current QueryBuilder instance :rtype: QueryBuilder """ - type = 'basic' + type = "basic" - self.havings.append({ - 'type': type, - 'column': column, - 'operator': operator, - 'value': value, - 'boolean': boolean - }) + self.havings.append( + { + "type": type, + "column": column, + "operator": operator, + "value": value, + "boolean": boolean, + } + ) if not isinstance(value, QueryExpression): - self.add_binding(value, 'having') + self.add_binding(value, "having") return self @@ -753,9 +758,9 @@ def or_having(self, column, operator=None, value=None): :return: The current QueryBuilder instance :rtype: QueryBuilder """ - return self.having(column, operator, value, 'or') + return self.having(column, operator, value, "or") - def having_raw(self, sql, bindings=None, boolean='and'): + def having_raw(self, sql, bindings=None, boolean="and"): """ Add a raw having clause to the query @@ -771,15 +776,11 @@ def having_raw(self, sql, bindings=None, boolean='and'): :return: The current QueryBuilder instance :rtype: QueryBuilder """ - type = 'raw' + type = "raw" - self.havings.append({ - 'type': type, - 'sql': sql, - 'boolean': boolean - }) + self.havings.append({"type": type, "sql": sql, "boolean": boolean}) - self.add_binding(bindings, 'having') + self.add_binding(bindings, "having") return self @@ -796,9 +797,9 @@ def or_having_raw(self, sql, bindings=None): :return: The current QueryBuilder instance :rtype: QueryBuilder """ - return self.having_raw(sql, bindings, 'or') + return self.having_raw(sql, bindings, "or") - def order_by(self, column, direction='asc'): + def order_by(self, column, direction="asc"): """ Add a "order by" clause to the query @@ -812,23 +813,20 @@ def order_by(self, column, direction='asc'): :rtype: QueryBuilder """ if 
self.unions: - prop = 'union_orders' + prop = "union_orders" else: - prop = 'orders' + prop = "orders" - if direction.lower() == 'asc': - direction = 'asc' + if direction.lower() == "asc": + direction = "asc" else: - direction = 'desc' + direction = "desc" - getattr(self, prop).append({ - 'column': column, - 'direction': direction - }) + getattr(self, prop).append({"column": column, "direction": direction}) return self - def latest(self, column='created_at'): + def latest(self, column="created_at"): """ Add an "order by" clause for a timestamp to the query in descending order @@ -839,9 +837,9 @@ def latest(self, column='created_at'): :return: The current QueryBuilder instance :rtype: QueryBuilder """ - return self.order_by(column, 'desc') + return self.order_by(column, "desc") - def oldest(self, column='created_at'): + def oldest(self, column="created_at"): """ Add an "order by" clause for a timestamp to the query in ascending order @@ -852,7 +850,7 @@ def oldest(self, column='created_at'): :return: The current QueryBuilder instance :rtype: QueryBuilder """ - return self.order_by(column, 'asc') + return self.order_by(column, "asc") def order_by_raw(self, sql, bindings=None): """ @@ -870,22 +868,19 @@ def order_by_raw(self, sql, bindings=None): if bindings is None: bindings = [] - type = 'raw' + type = "raw" - self.orders.append({ - 'type': type, - 'sql': sql - }) + self.orders.append({"type": type, "sql": sql}) - self.add_binding(bindings, 'order') + self.add_binding(bindings, "order") return self def offset(self, value): if self.unions: - prop = 'union_offset' + prop = "union_offset" else: - prop = 'offset_' + prop = "offset_" setattr(self, prop, max(0, value)) @@ -896,9 +891,9 @@ def skip(self, value): def limit(self, value): if self.unions: - prop = 'union_limit' + prop = "union_limit" else: - prop = 'limit_' + prop = "limit_" if value is None or value > 0: setattr(self, prop, value) @@ -924,10 +919,7 @@ def union(self, query, all=False): :return: The query 
:rtype: QueryBuilder """ - self.unions.append({ - 'query': query, - 'all': all - }) + self.unions.append({"query": query, "all": all}) return self.merge_bindings(query) @@ -998,9 +990,9 @@ def find(self, id, columns=None): :rtype: mixed """ if not columns: - columns = ['*'] + columns = ["*"] - return self.where('id', '=', id).first(1, columns) + return self.where("id", "=", id).first(1, columns) def pluck(self, column): """ @@ -1033,7 +1025,7 @@ def first(self, limit=1, columns=None): :rtype: mixed """ if not columns: - columns = ['*'] + columns = ["*"] return self.take(limit).get(columns).first() @@ -1048,7 +1040,7 @@ def get(self, columns=None): :rtype: Collection """ if not columns: - columns = ['*'] + columns = ["*"] original = self.columns @@ -1069,9 +1061,7 @@ def _run_select(self): :rtype: list """ return self._connection.select( - self.to_sql(), - self.get_bindings(), - not self._use_write_connection + self.to_sql(), self.get_bindings(), not self._use_write_connection ) def paginate(self, per_page=15, current_page=None, columns=None): @@ -1091,7 +1081,7 @@ def paginate(self, per_page=15, current_page=None, columns=None): :rtype: LengthAwarePaginator """ if columns is None: - columns = ['*'] + columns = ["*"] page = current_page or Paginator.resolve_current_page() @@ -1118,7 +1108,7 @@ def simple_paginate(self, per_page=15, current_page=None, columns=None): :rtype: Paginator """ if columns is None: - columns = ['*'] + columns = ["*"] page = current_page or Paginator.resolve_current_page() @@ -1136,14 +1126,20 @@ def get_count_for_pagination(self): return total def _backup_fields_for_count(self): - for field in ['orders', 'limit', 'offset']: - self._backups[field] = getattr(self, field) + for field, binding in [("orders", "order"), ("limit", None), ("offset", None)]: + self._backups[field] = {} + self._backups[field]["query"] = getattr(self, field) + if binding is not None: + self._backups[field]["binding"] = self.get_raw_bindings()[binding] + 
self.set_bindings([], binding) setattr(self, field, None) def _restore_fields_for_count(self): - for field in ['orders', 'limit', 'offset']: - setattr(self, field, self._backups[field]) + for field, binding in [("orders", "order"), ("limit", None), ("offset", None)]: + setattr(self, field, self._backups[field]["query"]) + if binding is not None and self._backups[field]["binding"] is not None: + self.add_binding(self._backups[field]["binding"], binding) self._backups = {} @@ -1157,15 +1153,10 @@ def chunk(self, count): :return: The current chunk :rtype: list """ - page = 1 - results = self.for_page(page, count).get() - - while not results.is_empty(): - yield results - - page += 1 - - results = self.for_page(page, count).get() + for chunk in self._connection.select_many( + count, self.to_sql(), self.get_bindings(), not self._use_write_connection + ): + yield chunk def lists(self, column, key=None): """ @@ -1211,16 +1202,16 @@ def _get_list_select(self, column, key=None): select = [] for elem in elements: - dot = elem.find('.') + dot = elem.find(".") if dot >= 0: - select.append(column[dot + 1:]) + select.append(column[dot + 1 :]) else: select.append(elem) return select - def implode(self, column, glue=''): + def implode(self, column, glue=""): """ Concatenate values of a given column as a string. 
@@ -1260,10 +1251,13 @@ def count(self, *columns): :return: The count :rtype: int """ + if not columns and self.distinct_: + columns = self.columns + if not columns: - columns = ['*'] + columns = ["*"] - return int(self.aggregate('count', *columns)) + return int(self.aggregate("count", *columns)) def min(self, column): """ @@ -1275,7 +1269,7 @@ def min(self, column): :return: The min :rtype: int """ - return self.aggregate('min', *[column]) + return self.aggregate("min", *[column]) def max(self, column): """ @@ -1288,9 +1282,9 @@ def max(self, column): :rtype: int """ if not column: - columns = ['*'] + columns = ["*"] - return self.aggregate('max', *[column]) + return self.aggregate("max", *[column]) def sum(self, column): """ @@ -1302,7 +1296,7 @@ def sum(self, column): :return: The sum :rtype: int """ - return self.aggregate('sum', *[column]) + return self.aggregate("sum", *[column]) def avg(self, column): """ @@ -1315,7 +1309,7 @@ def avg(self, column): :rtype: int """ - return self.aggregate('avg', *[column]) + return self.aggregate("avg", *[column]) def aggregate(self, func, *columns): """ @@ -1331,12 +1325,9 @@ def aggregate(self, func, *columns): :rtype: mixed """ if not columns: - columns = ['*'] + columns = ["*"] - self.aggregate_ = { - 'function': func, - 'columns': columns - } + self.aggregate_ = {"function": func, "columns": columns} previous_columns = self.columns @@ -1347,7 +1338,7 @@ def aggregate(self, func, *columns): self.columns = previous_columns if len(results) > 0: - return dict((k.lower(), v) for k, v in results[0].items())['aggregate'] + return dict((k.lower(), v) for k, v in results[0].items())["aggregate"] def insert(self, _values=None, **values): """ @@ -1450,9 +1441,7 @@ def increment(self, column, amount=1, extras=None): if extras is None: extras = {} - columns = { - column: self.raw('%s + %s' % (wrapped, amount)) - } + columns = {column: self.raw("%s + %s" % (wrapped, amount))} columns.update(extras) return self.update(**columns) @@ 
-1478,9 +1467,7 @@ def decrement(self, column, amount=1, extras=None): if extras is None: extras = {} - columns = { - column: self.raw('%s - %s' % (wrapped, amount)) - } + columns = {column: self.raw("%s - %s" % (wrapped, amount))} columns.update(extras) return self.update(**columns) @@ -1496,7 +1483,7 @@ def delete(self, id=None): :rtype: int """ if id is not None: - self.where('id', '=', id) + self.where("id", "=", id) sql = self._grammar.compile_delete(self) @@ -1533,7 +1520,7 @@ def merge_wheres(self, wheres, bindings): :rtype: None """ self.wheres = self.wheres + wheres - self._bindings['where'] = self._bindings['where'] + bindings + self._bindings["where"] = self._bindings["where"] + bindings def _clean_bindings(self, bindings): """ @@ -1560,25 +1547,32 @@ def raw(self, value): return self._connection.raw(value) def get_bindings(self): - return list(chain(*self._bindings.values())) + bindings = [] + for value in chain(*self._bindings.values()): + if isinstance(value, datetime.date): + value = value.strftime(self._grammar.get_date_format()) + + bindings.append(value) + + return bindings def get_raw_bindings(self): return self._bindings - def set_bindings(self, bindings, type='where'): + def set_bindings(self, bindings, type="where"): if type not in self._bindings: - raise ArgumentError('Invalid binding type: %s' % type) + raise ArgumentError("Invalid binding type: %s" % type) self._bindings[type] = bindings return self - def add_binding(self, value, type='where'): + def add_binding(self, value, type="where"): if value is None: return self if type not in self._bindings: - raise ArgumentError('Invalid binding type: %s' % type) + raise ArgumentError("Invalid binding type: %s" % type) if isinstance(value, (list, tuple)): self._bindings[type] += value @@ -1606,6 +1600,7 @@ def merge(self, query): self.groups += query.groups self.havings += query.havings self.orders += query.orders + self.distinct_ = query.distinct_ if self.columns: self.columns = 
Collection(self.columns).unique().all() @@ -1661,7 +1656,7 @@ def use_write_connection(self): return self def __getattr__(self, item): - if item.startswith('where_'): + if item.startswith("where_"): return self.dynamic_where(item) raise AttributeError(item) @@ -1669,8 +1664,15 @@ def __getattr__(self, item): def __copy__(self): new = self.__class__(self._connection, self._grammar, self._processor) - new.__dict__.update(dict((k, copy.deepcopy(v)) for k, v - in self.__dict__.items() - if k != '_connection')) + new.__dict__.update( + dict( + (k, copy.deepcopy(v)) + for k, v in self.__dict__.items() + if k != "_connection" + ) + ) return new + + def __deepcopy__(self, memo): + return self.__copy__() diff --git a/orator/query/expression.py b/orator/query/expression.py index 911ba33f..711de337 100644 --- a/orator/query/expression.py +++ b/orator/query/expression.py @@ -2,7 +2,6 @@ class QueryExpression(object): - def __init__(self, value): self._value = value diff --git a/orator/query/grammars/grammar.py b/orator/query/grammars/grammar.py index 58e26db9..ffd1d74d 100644 --- a/orator/query/grammars/grammar.py +++ b/orator/query/grammars/grammar.py @@ -9,23 +9,23 @@ class QueryGrammar(Grammar): _select_components = [ - 'aggregate_', - 'columns', - 'from__', - 'joins', - 'wheres', - 'groups', - 'havings', - 'orders', - 'limit_', - 'offset_', - 'unions', - 'lock_' + "aggregate_", + "columns", + "from__", + "joins", + "wheres", + "groups", + "havings", + "orders", + "limit_", + "offset_", + "unions", + "lock_", ] def compile_select(self, query): if not query.columns: - query.columns = ['*'] + query.columns = ["*"] return self._concatenate(self._compile_components(query)).strip() @@ -38,20 +38,19 @@ def _compile_components(self, query): # function for the component which is responsible for making the SQL. 
component_value = getattr(query, component) if component_value is not None: - method = '_compile_%s' % component.replace('_', '') + method = "_compile_%s" % component.replace("_", "") sql[component] = getattr(self, method)(query, component_value) return sql def _compile_aggregate(self, query, aggregate): - column = self.columnize(aggregate['columns']) + column = self.columnize(aggregate["columns"]) - if query.distinct_ and column != '*': - column = 'DISTINCT %s' % column + if query.distinct_ and column != "*": + column = "DISTINCT %s" % column - return 'SELECT %s(%s) AS aggregate' % (aggregate['function'].upper(), - column) + return "SELECT %s(%s) AS aggregate" % (aggregate["function"].upper(), column) def _compile_columns(self, query, columns): # If the query is actually performing an aggregating select, we will let that @@ -61,19 +60,19 @@ def _compile_columns(self, query, columns): return if query.distinct_: - select = 'SELECT DISTINCT ' + select = "SELECT DISTINCT " else: - select = 'SELECT ' + select = "SELECT " - return '%s%s' % (select, self.columnize(columns)) + return "%s%s" % (select, self.columnize(columns)) def _compile_from(self, query, table): - return 'FROM %s' % self.wrap_table(table) + return "FROM %s" % self.wrap_table(table) def _compile_joins(self, query, joins): sql = [] - query.set_bindings([], 'join') + query.set_bindings([], "join") for join in joins: table = self.wrap_table(join.table) @@ -87,14 +86,14 @@ def _compile_joins(self, query, joins): clauses.append(self._compile_join_constraints(clause)) for binding in join.bindings: - query.add_binding(binding, 'join') + query.add_binding(binding, "join") # Once we have constructed the clauses, we'll need to take the boolean connector # off of the first clause as it obviously will not be required on that clause # because it leads the rest of the clauses, thus not requiring any boolean. 
clauses[0] = self._remove_leading_boolean(clauses[0]) - clauses = ' '.join(clauses) + clauses = " ".join(clauses) type = join.type @@ -102,212 +101,230 @@ def _compile_joins(self, query, joins): # build the final join statement SQL for the query and we can then return the # final clause back to the callers as a single, stringified join statement. - sql.append('%s JOIN %s ON %s' % (type.upper(), table, clauses)) + sql.append("%s JOIN %s ON %s" % (type.upper(), table, clauses)) - return ' '.join(sql) + return " ".join(sql) def _compile_join_constraints(self, clause): - first = self.wrap(clause['first']) + first = self.wrap(clause["first"]) - if clause['where']: + if clause["where"]: second = self.get_marker() else: - second = self.wrap(clause['second']) + second = self.wrap(clause["second"]) - return '%s %s %s %s' % (clause['boolean'].upper(), first, - clause['operator'], second) + return "%s %s %s %s" % ( + clause["boolean"].upper(), + first, + clause["operator"], + second, + ) def _compile_wheres(self, query, _=None): sql = [] if query.wheres is None: - return '' + return "" # Each type of where clauses has its own compiler function which is responsible # for actually creating the where clauses SQL. This helps keep the code nice # and maintainable since each clause has a very small method that it uses. for where in query.wheres: - method = '_where_%s' % where['type'] + method = "_where_%s" % where["type"] - sql.append('%s %s' % (where['boolean'].upper(), - getattr(self, method)(query, where))) + sql.append( + "%s %s" + % (where["boolean"].upper(), getattr(self, method)(query, where)) + ) # If we actually have some where clauses, we will strip off the first boolean # operator, which is added by the query builders for convenience so we can # avoid checking for the first clauses in each of the compilers methods. 
if len(sql) > 0: - sql = ' '.join(sql) + sql = " ".join(sql) - return 'WHERE %s' % re.sub('AND |OR ', '', sql, 1, re.I) + return "WHERE %s" % re.sub("AND |OR ", "", sql, 1, re.I) - return '' + return "" def _where_nested(self, query, where): - nested = where['query'] + nested = where["query"] - return '(%s)' % (self._compile_wheres(nested)[6:]) + return "(%s)" % (self._compile_wheres(nested)[6:]) def _where_sub(self, query, where): - select = self.compile_select(where['query']) + select = self.compile_select(where["query"]) - return '%s %s (%s)' % (self.wrap(where['column']), - where['operator'], select) + return "%s %s (%s)" % (self.wrap(where["column"]), where["operator"], select) def _where_basic(self, query, where): - value = self.parameter(where['value']) + value = self.parameter(where["value"]) - return '%s %s %s' % (self.wrap(where['column']), - where['operator'], value) + return "%s %s %s" % (self.wrap(where["column"]), where["operator"], value) def _where_between(self, query, where): - if where['not']: - between = 'NOT BETWEEN' + if where["not"]: + between = "NOT BETWEEN" else: - between = 'BETWEEN' + between = "BETWEEN" - return '%s %s %s AND %s' % (self.wrap(where['column']), between, - self.get_marker(), self.get_marker()) + return "%s %s %s AND %s" % ( + self.wrap(where["column"]), + between, + self.get_marker(), + self.get_marker(), + ) def _where_exists(self, query, where): - return 'EXISTS (%s)' % self.compile_select(where['query']) + return "EXISTS (%s)" % self.compile_select(where["query"]) def _where_not_exists(self, query, where): - return 'NOT EXISTS (%s)' % self.compile_select(where['query']) + return "NOT EXISTS (%s)" % self.compile_select(where["query"]) def _where_in(self, query, where): - if not where['values']: - return '0 = 1' + if not where["values"]: + return "0 = 1" - values = self.parameterize(where['values']) + values = self.parameterize(where["values"]) - return '%s IN (%s)' % (self.wrap(where['column']), values) + return "%s IN 
(%s)" % (self.wrap(where["column"]), values) def _where_not_in(self, query, where): - if not where['values']: - return '1 = 1' + if not where["values"]: + return "1 = 1" - values = self.parameterize(where['values']) + values = self.parameterize(where["values"]) - return '%s NOT IN (%s)' % (self.wrap(where['column']), values) + return "%s NOT IN (%s)" % (self.wrap(where["column"]), values) def _where_in_sub(self, query, where): - select = self.compile_select(where['query']) + select = self.compile_select(where["query"]) - return '%s IN (%s)' % (self.wrap(where['column']), select) + return "%s IN (%s)" % (self.wrap(where["column"]), select) def _where_not_in_sub(self, query, where): - select = self.compile_select(where['query']) + select = self.compile_select(where["query"]) - return '%s NOT IN (%s)' % (self.wrap(where['column']), select) + return "%s NOT IN (%s)" % (self.wrap(where["column"]), select) def _where_null(self, query, where): - return '%s IS NULL' % self.wrap(where['column']) + return "%s IS NULL" % self.wrap(where["column"]) def _where_not_null(self, query, where): - return '%s IS NOT NULL' % self.wrap(where['column']) + return "%s IS NOT NULL" % self.wrap(where["column"]) def _where_date(self, query, where): - return self._date_based_where('date', query, where) + return self._date_based_where("date", query, where) def _where_day(self, query, where): - return self._date_based_where('day', query, where) + return self._date_based_where("day", query, where) def _where_month(self, query, where): - return self._date_based_where('month', query, where) + return self._date_based_where("month", query, where) def _where_year(self, query, where): - return self._date_based_where('year', query, where) + return self._date_based_where("year", query, where) def _date_based_where(self, type, query, where): - value = self.parameter(where['value']) + value = self.parameter(where["value"]) - return '%s(%s) %s %s' % (type.upper(), self.wrap(where['column']), - 
where['operator'], value) + return "%s(%s) %s %s" % ( + type.upper(), + self.wrap(where["column"]), + where["operator"], + value, + ) def _where_raw(self, query, where): - return re.sub('( and | or )', - lambda m: m.group(1).upper(), - where['sql'], - re.I) + return re.sub("( and | or )", lambda m: m.group(1).upper(), where["sql"], re.I) def _compile_groups(self, query, groups): if not groups: - return '' + return "" - return 'GROUP BY %s' % self.columnize(groups) + return "GROUP BY %s" % self.columnize(groups) def _compile_havings(self, query, havings): if not havings: - return '' + return "" - sql = ' '.join(map(self._compile_having, havings)) + sql = " ".join(map(self._compile_having, havings)) - return 'HAVING %s' % re.sub('and |or ', '', sql, 1, re.I) + return "HAVING %s" % re.sub("and |or ", "", sql, 1, re.I) def _compile_having(self, having): # If the having clause is "raw", we can just return the clause straight away # without doing any more processing on it. Otherwise, we will compile the # clause into SQL based on the components that make it up from builder. 
- if having['type'] == 'raw': - return '%s %s' % (having['boolean'].upper(), having['sql']) + if having["type"] == "raw": + return "%s %s" % (having["boolean"].upper(), having["sql"]) return self._compile_basic_having(having) def _compile_basic_having(self, having): - column = self.wrap(having['column']) + column = self.wrap(having["column"]) - parameter = self.parameter(having['value']) + parameter = self.parameter(having["value"]) - return '%s %s %s %s' % (having['boolean'].upper(), column, - having['operator'], parameter) + return "%s %s %s %s" % ( + having["boolean"].upper(), + column, + having["operator"], + parameter, + ) def _compile_orders(self, query, orders): if not orders: - return '' + return "" compiled = [] for order in orders: - if order.get('sql'): - compiled.append(re.sub('( desc| asc)( |$)', - lambda m: '%s%s' % (m.group(1).upper(), m.group(2)), - order['sql'], - re.I)) + if order.get("sql"): + compiled.append( + re.sub( + "( desc| asc)( |$)", + lambda m: "%s%s" % (m.group(1).upper(), m.group(2)), + order["sql"], + re.I, + ) + ) else: - compiled.append('%s %s' % (self.wrap(order['column']), - order['direction'].upper())) + compiled.append( + "%s %s" % (self.wrap(order["column"]), order["direction"].upper()) + ) - return 'ORDER BY %s' % ', '.join(compiled) + return "ORDER BY %s" % ", ".join(compiled) def _compile_limit(self, query, limit): - return 'LIMIT %s' % int(limit) + return "LIMIT %s" % int(limit) def _compile_offset(self, query, offset): - return 'OFFSET %s' % int(offset) + return "OFFSET %s" % int(offset) def _compile_unions(self, query, _=None): - sql = '' + sql = "" for union in query.unions: sql += self._compile_union(union) if query.union_orders: - sql += ' %s' % self._compile_orders(query, query.union_orders) + sql += " %s" % self._compile_orders(query, query.union_orders) if query.union_limit: - sql += ' %s' % self._compile_limit(query, query.union_limit) + sql += " %s" % self._compile_limit(query, query.union_limit) if 
query.union_offset: - sql += ' %s' % self._compile_offset(query, query.union_offset) + sql += " %s" % self._compile_offset(query, query.union_offset) return sql.lstrip() def _compile_union(self, union): - if union['all']: - joiner = ' UNION ALL ' + if union["all"]: + joiner = " UNION ALL " else: - joiner = ' UNION ' + joiner = " UNION " - return '%s%s' % (joiner, union['query'].to_sql()) + return "%s%s" % (joiner, union["query"].to_sql()) def compile_insert(self, query, values): """ @@ -337,11 +354,11 @@ def compile_insert(self, query, values): # bindings so we can just go off the first list of values in this array. parameters = self.parameterize(values[0].values()) - value = ['(%s)' % parameters] * len(values) + value = ["(%s)" % parameters] * len(values) - parameters = ', '.join(value) + parameters = ", ".join(value) - return 'INSERT INTO %s (%s) VALUES %s' % (table, columns, parameters) + return "INSERT INTO %s (%s) VALUES %s" % (table, columns, parameters) def compile_insert_get_id(self, query, values, sequence): return self.compile_insert(query, values) @@ -355,24 +372,24 @@ def compile_update(self, query, values): columns = [] for key, value in values.items(): - columns.append('%s = %s' % (self.wrap(key), self.parameter(value))) + columns.append("%s = %s" % (self.wrap(key), self.parameter(value))) - columns = ', '.join(columns) + columns = ", ".join(columns) # If the query has any "join" clauses, we will setup the joins on the builder # and compile them so we can attach them to this update, as update queries # can get join statements to attach to other tables when they're needed. 
if query.joins: - joins = ' %s' % self._compile_joins(query, query.joins) + joins = " %s" % self._compile_joins(query, query.joins) else: - joins = '' + joins = "" # Of course, update queries may also be constrained by where clauses so we'll # need to compile the where clauses and attach it to the query so only the # intended records are updated by the SQL statements we generate to run. where = self._compile_wheres(query) - return ('UPDATE %s%s SET %s %s' % (table, joins, columns, where)).strip() + return ("UPDATE %s%s SET %s %s" % (table, joins, columns, where)).strip() def compile_delete(self, query): table = self.wrap_table(query.from__) @@ -380,20 +397,18 @@ def compile_delete(self, query): if isinstance(query.wheres, list): where = self._compile_wheres(query) else: - where = '' + where = "" - return ('DELETE FROM %s %s' % (table, where)).strip() + return ("DELETE FROM %s %s" % (table, where)).strip() def compile_truncate(self, query): - return { - 'TRUNCATE %s' % self.wrap_table(query.from__): [] - } + return {"TRUNCATE %s" % self.wrap_table(query.from__): []} def _compile_lock(self, query, value): if isinstance(value, basestring): return value else: - return '' + return "" def _concatenate(self, segments): parts = [] @@ -403,9 +418,7 @@ def _concatenate(self, segments): if value: parts.append(value) - return ' '.join(parts) + return " ".join(parts) def _remove_leading_boolean(self, value): - return re.sub('and | or ', '', value, 1, re.I) - - + return re.sub("and | or ", "", value, 1, re.I) diff --git a/orator/query/grammars/mysql_grammar.py b/orator/query/grammars/mysql_grammar.py index 39139039..909c5993 100644 --- a/orator/query/grammars/mysql_grammar.py +++ b/orator/query/grammars/mysql_grammar.py @@ -7,20 +7,20 @@ class MySQLQueryGrammar(QueryGrammar): _select_components = [ - 'aggregate_', - 'columns', - 'from__', - 'joins', - 'wheres', - 'groups', - 'havings', - 'orders', - 'limit_', - 'offset_', - 'lock_' + "aggregate_", + "columns", + "from__", + 
"joins", + "wheres", + "groups", + "havings", + "orders", + "limit_", + "offset_", + "lock_", ] - marker = '%s' + marker = "%s" def compile_select(self, query): """ @@ -35,7 +35,7 @@ def compile_select(self, query): sql = super(MySQLQueryGrammar, self).compile_select(query) if query.unions: - sql = '(%s) %s' % (sql, self._compile_unions(query)) + sql = "(%s) %s" % (sql, self._compile_unions(query)) return sql @@ -49,12 +49,12 @@ def _compile_union(self, union): :return: The compiled union statement :rtype: str """ - if union['all']: - joiner = ' UNION ALL ' + if union["all"]: + joiner = " UNION ALL " else: - joiner = ' UNION ' + joiner = " UNION " - return '%s(%s)' % (joiner, union['query'].to_sql()) + return "%s(%s)" % (joiner, union["query"].to_sql()) def _compile_lock(self, query, value): """ @@ -73,9 +73,9 @@ def _compile_lock(self, query, value): return value if value is True: - return 'FOR UPDATE' + return "FOR UPDATE" elif value is False: - return 'LOCK IN SHARE MODE' + return "LOCK IN SHARE MODE" def compile_update(self, query, values): """ @@ -93,10 +93,10 @@ def compile_update(self, query, values): sql = super(MySQLQueryGrammar, self).compile_update(query, values) if query.orders: - sql += ' %s' % self._compile_orders(query, query.orders) + sql += " %s" % self._compile_orders(query, query.orders) if query.limit_: - sql += ' %s' % self._compile_limit(query, query.limit_) + sql += " %s" % self._compile_limit(query, query.limit_) return sql.rstrip() @@ -115,22 +115,22 @@ def compile_delete(self, query): if isinstance(query.wheres, list): wheres = self._compile_wheres(query) else: - wheres = '' + wheres = "" if query.joins: - joins = ' %s' % self._compile_joins(query, query.joins) + joins = " %s" % self._compile_joins(query, query.joins) - sql = 'DELETE %s FROM %s%s %s' % (table, table, joins, wheres) + sql = "DELETE %s FROM %s%s %s" % (table, table, joins, wheres) else: - sql = 'DELETE FROM %s %s' % (table, wheres) + sql = "DELETE FROM %s %s" % (table, 
wheres) sql = sql.strip() if query.orders: - sql += ' %s' % self._compile_orders(query, query.orders) + sql += " %s" % self._compile_orders(query, query.orders) if query.limit_: - sql += ' %s' % self._compile_limit(query, query.limit_) + sql += " %s" % self._compile_limit(query, query.limit_) return sql @@ -144,7 +144,7 @@ def _wrap_value(self, value): :return: The wrapped value :rtype: str """ - if value == '*': + if value == "*": return value - return '`%s`' % value.replace('`', '``') + return "`%s`" % value.replace("`", "``") diff --git a/orator/query/grammars/postgres_grammar.py b/orator/query/grammars/postgres_grammar.py index cd33f9d4..90dc8d2a 100644 --- a/orator/query/grammars/postgres_grammar.py +++ b/orator/query/grammars/postgres_grammar.py @@ -7,12 +7,25 @@ class PostgresQueryGrammar(QueryGrammar): _operators = [ - '=', '<', '>', '<=', '>=', '<>', '!=', - 'like', 'not like', 'between', 'ilike', - '&', '|', '#', '<<', '>>' + "=", + "<", + ">", + "<=", + ">=", + "<>", + "!=", + "like", + "not like", + "between", + "ilike", + "&", + "|", + "#", + "<<", + ">>", ] - marker = '%s' + marker = "%s" def _compile_lock(self, query, value): """ @@ -31,9 +44,9 @@ def _compile_lock(self, query, value): return value if value: - return 'FOR UPDATE' + return "FOR UPDATE" - return 'FOR SHARE' + return "FOR SHARE" def compile_update(self, query, values): """ @@ -56,7 +69,7 @@ def compile_update(self, query, values): where = self._compile_update_wheres(query) - return ('UPDATE %s SET %s%s %s' % (table, columns, from_, where)).strip() + return ("UPDATE %s SET %s%s %s" % (table, columns, from_, where)).strip() def _compile_update_columns(self, values): """ @@ -71,9 +84,9 @@ def _compile_update_columns(self, values): columns = [] for key, value in values.items(): - columns.append('%s = %s' % (self.wrap(key), self.parameter(value))) + columns.append("%s = %s" % (self.wrap(key), self.parameter(value))) - return ', '.join(columns) + return ", ".join(columns) def 
_compile_update_from(self, query): """ @@ -86,7 +99,7 @@ def _compile_update_from(self, query): :rtype: str """ if not query.joins: - return '' + return "" froms = [] @@ -94,9 +107,9 @@ def _compile_update_from(self, query): froms.append(self.wrap_table(join.table)) if len(froms): - return ' FROM %s' % ', '.join(froms) + return " FROM %s" % ", ".join(froms) - return '' + return "" def _compile_update_wheres(self, query): """ @@ -116,9 +129,9 @@ def _compile_update_wheres(self, query): join_where = self._compile_update_join_wheres(query) if not base_where.strip(): - return 'WHERE %s' % self._remove_leading_boolean(join_where) + return "WHERE %s" % self._remove_leading_boolean(join_where) - return '%s %s' % (base_where, join_where) + return "%s %s" % (base_where, join_where) def _compile_update_join_wheres(self, query): """ @@ -136,7 +149,7 @@ def _compile_update_join_wheres(self, query): for clause in join.clauses: join_wheres.append(self._compile_join_constraints(clause)) - return ' '.join(join_wheres) + return " ".join(join_wheres) def compile_insert_get_id(self, query, values, sequence=None): """ @@ -155,10 +168,12 @@ def compile_insert_get_id(self, query, values, sequence=None): :rtype: str """ if sequence is None: - sequence = 'id' + sequence = "id" - return '%s RETURNING %s'\ - % (self.compile_insert(query, values), self.wrap(sequence)) + return "%s RETURNING %s" % ( + self.compile_insert(query, values), + self.wrap(sequence), + ) def compile_truncate(self, query): """ @@ -170,6 +185,4 @@ def compile_truncate(self, query): :return: The compiled statement :rtype: str """ - return { - 'TRUNCATE %s RESTART IDENTITY' % self.wrap_table(query.from__): {} - } + return {"TRUNCATE %s RESTART IDENTITY" % self.wrap_table(query.from__): {}} diff --git a/orator/query/grammars/sqlite_grammar.py b/orator/query/grammars/sqlite_grammar.py index fb9399be..f2bd623d 100644 --- a/orator/query/grammars/sqlite_grammar.py +++ b/orator/query/grammars/sqlite_grammar.py @@ -6,9 +6,21 @@ 
class SQLiteQueryGrammar(QueryGrammar): _operators = [ - '=', '<', '>', '<=', '>=', '<>', '!=', - 'like', 'not like', 'between', 'ilike', - '&', '|', '<<', '>>', + "=", + "<", + ">", + "<=", + ">=", + "<>", + "!=", + "like", + "not like", + "between", + "ilike", + "&", + "|", + "<<", + ">>", ] def compile_insert(self, query, values): @@ -41,12 +53,15 @@ def compile_insert(self, query, values): # unions joining them together. So we'll build out this list of columns and # then join them all together with select unions to complete the queries. for column in values[0].keys(): - columns.append('%s AS %s' % (self.get_marker(), self.wrap(column))) + columns.append("%s AS %s" % (self.get_marker(), self.wrap(column))) - columns = [', '.join(columns)] * len(values) + columns = [", ".join(columns)] * len(values) - return 'INSERT INTO %s (%s) SELECT %s'\ - % (table, names, ' UNION ALL SELECT '.join(columns)) + return "INSERT INTO %s (%s) SELECT %s" % ( + table, + names, + " UNION ALL SELECT ".join(columns), + ) def compile_truncate(self, query): """ @@ -59,13 +74,29 @@ def compile_truncate(self, query): :rtype: str """ sql = { - 'DELETE FROM sqlite_sequence WHERE name = %s' % self.get_marker(): [query.from__] + "DELETE FROM sqlite_sequence WHERE name = %s" + % self.get_marker(): [query.from__] } - sql['DELETE FROM %s' % self.wrap_table(query.from__)] = [] + sql["DELETE FROM %s" % self.wrap_table(query.from__)] = [] return sql + def _where_date(self, query, where): + """ + Compile a "where date" clause + + :param query: A QueryBuilder instance + :type query: QueryBuilder + + :param where: The condition + :type where: dict + + :return: The compiled clause + :rtype: str + """ + return self._date_based_where("%Y-%m-%d", query, where) + def _where_day(self, query, where): """ Compile a "where day" clause @@ -79,7 +110,7 @@ def _where_day(self, query, where): :return: The compiled clause :rtype: str """ - return self._date_based_where('%d', query, where) + return 
self._date_based_where("%d", query, where) def _where_month(self, query, where): """ @@ -94,7 +125,7 @@ def _where_month(self, query, where): :return: The compiled clause :rtype: str """ - return self._date_based_where('%m', query, where) + return self._date_based_where("%m", query, where) def _where_year(self, query, where): """ @@ -109,7 +140,7 @@ def _where_year(self, query, where): :return: The compiled clause :rtype: str """ - return self._date_based_where('%Y', query, where) + return self._date_based_where("%Y", query, where) def _date_based_where(self, type, query, where): """ @@ -127,9 +158,12 @@ def _date_based_where(self, type, query, where): :return: The compiled clause :rtype: str """ - value = str(where['value']).zfill(2) + value = str(where["value"]).zfill(2) value = self.parameter(value) - return 'strftime(\'%s\', %s) %s %s'\ - % (type, self.wrap(where['column']), - where['operator'], value) + return "strftime('%s', %s) %s %s" % ( + type, + self.wrap(where["column"]), + where["operator"], + value, + ) diff --git a/orator/query/join_clause.py b/orator/query/join_clause.py index adccc305..f3c2be0a 100644 --- a/orator/query/join_clause.py +++ b/orator/query/join_clause.py @@ -4,22 +4,23 @@ class JoinClause(object): - - def __init__(self, table, type='inner'): + def __init__(self, table, type="inner"): self.type = type self.table = table self.clauses = [] self.bindings = [] - def on(self, first, operator, second, boolean='and', where=False): - self.clauses.append({ - 'first': first, - 'operator': operator, - 'second': second, - 'boolean': boolean, - 'where': where - }) + def on(self, first, operator, second, boolean="and", where=False): + self.clauses.append( + { + "first": first, + "operator": operator, + "second": second, + "boolean": boolean, + "where": where, + } + ) if where: self.bindings.append(second) @@ -27,22 +28,22 @@ def on(self, first, operator, second, boolean='and', where=False): return self def or_on(self, first, operator, second): - 
return self.on(first, operator, second, 'or') + return self.on(first, operator, second, "or") - def where(self, first, operator, second, boolean='and'): + def where(self, first, operator, second, boolean="and"): return self.on(first, operator, second, boolean, True) def or_where(self, first, operator, second): - return self.where(first, operator, second, 'or') + return self.where(first, operator, second, "or") - def where_null(self, column, boolean='and'): - return self.on(column, 'IS', QueryExpression('NULL'), boolean, False) + def where_null(self, column, boolean="and"): + return self.on(column, "IS", QueryExpression("NULL"), boolean, False) def or_where_null(self, column): - return self.where_null(column, 'or') + return self.where_null(column, "or") - def where_not_null(self, column, boolean='and'): - return self.on(column, 'IS', QueryExpression('NOT NULL'), boolean, False) + def where_not_null(self, column, boolean="and"): + return self.on(column, "IS", QueryExpression("NOT NULL"), boolean, False) def or_where_not_null(self, column): - return self.where_not_null(column, 'or') + return self.where_not_null(column, "or") diff --git a/orator/query/processors/mysql_processor.py b/orator/query/processors/mysql_processor.py index 289df386..3a7d5596 100644 --- a/orator/query/processors/mysql_processor.py +++ b/orator/query/processors/mysql_processor.py @@ -4,7 +4,6 @@ class MySQLQueryProcessor(QueryProcessor): - def process_insert_get_id(self, query, sql, values, sequence=None): """ Process an "insert get ID" query. 
@@ -29,18 +28,18 @@ def process_insert_get_id(self, query, sql, values, sequence=None): query.get_connection().insert(sql, values) cursor = query.get_connection().get_cursor() - if hasattr(cursor, 'lastrowid'): + if hasattr(cursor, "lastrowid"): id = cursor.lastrowid else: - id = query.get_connection().statement('SELECT LAST_INSERT_ID()') + id = query.get_connection().statement("SELECT LAST_INSERT_ID()") else: query.get_connection().insert(sql, values) cursor = query.get_connection().get_cursor() - if hasattr(cursor, 'lastrowid'): + if hasattr(cursor, "lastrowid"): id = cursor.lastrowid else: - id = query.get_connection().statement('SELECT LAST_INSERT_ID()') + id = query.get_connection().statement("SELECT LAST_INSERT_ID()") if isinstance(id, int): return id @@ -60,4 +59,4 @@ def process_column_listing(self, results): :return: The processed results :return: list """ - return list(map(lambda x: x['column_name'], results)) + return list(map(lambda x: x["column_name"], results)) diff --git a/orator/query/processors/postgres_processor.py b/orator/query/processors/postgres_processor.py index 2b60d340..a5b88bb1 100644 --- a/orator/query/processors/postgres_processor.py +++ b/orator/query/processors/postgres_processor.py @@ -4,7 +4,6 @@ class PostgresQueryProcessor(QueryProcessor): - def process_insert_get_id(self, query, sql, values, sequence=None): """ Process an "insert get ID" query. 
@@ -46,4 +45,4 @@ def process_column_listing(self, results): :return: The processed results :return: list """ - return list(map(lambda x: x['column_name'], results)) + return list(map(lambda x: x["column_name"], results)) diff --git a/orator/query/processors/processor.py b/orator/query/processors/processor.py index 08812dd8..61909456 100644 --- a/orator/query/processors/processor.py +++ b/orator/query/processors/processor.py @@ -2,7 +2,6 @@ class QueryProcessor(object): - def process_select(self, query, results): """ Process the results of a "select" query diff --git a/orator/query/processors/sqlite_processor.py b/orator/query/processors/sqlite_processor.py index 77cbf2a8..55e6b945 100644 --- a/orator/query/processors/sqlite_processor.py +++ b/orator/query/processors/sqlite_processor.py @@ -4,7 +4,6 @@ class SQLiteQueryProcessor(QueryProcessor): - def process_column_listing(self, results): """ Process the results of a column listing query @@ -15,4 +14,4 @@ def process_column_listing(self, results): :return: The processed results :return: list """ - return list(map(lambda x: x['name'], results)) + return list(map(lambda x: x["name"], results)) diff --git a/orator/schema/blueprint.py b/orator/schema/blueprint.py index ea69a704..d32aadd6 100644 --- a/orator/schema/blueprint.py +++ b/orator/schema/blueprint.py @@ -4,7 +4,6 @@ class Blueprint(object): - def __init__(self, table): """ :param table: The table to operate on @@ -47,7 +46,7 @@ def to_sql(self, connection, grammar): statements = [] for command in self._commands: - method = 'compile_%s' % command.name + method = "compile_%s" % command.name if hasattr(grammar, method): sql = getattr(grammar, method)(self, command, connection) @@ -64,10 +63,10 @@ def _add_implied_commands(self): Add the commands that are implied by the blueprint. 
""" if len(self.get_added_columns()) and not self._creating(): - self._commands.insert(0, self._create_command('add')) + self._commands.insert(0, self._create_command("add")) if len(self.get_changed_columns()) and not self._creating(): - self._commands.insert(0, self._create_command('change')) + self._commands.insert(0, self._create_command("change")) return self._add_fluent_indexes() @@ -76,7 +75,7 @@ def _add_fluent_indexes(self): Add the index commands fluently specified on columns: """ for column in self._columns: - for index in ['primary', 'unique', 'index']: + for index in ["primary", "unique", "index"]: column_index = column.get(index) if column_index is True: @@ -95,7 +94,7 @@ def _creating(self): :rtype: bool """ for command in self._commands: - if command.name == 'create': + if command.name == "create": return True return False @@ -106,7 +105,7 @@ def create(self): :rtype: Fluent """ - return self._add_command('create') + return self._add_command("create") def drop(self): """ @@ -114,7 +113,7 @@ def drop(self): :rtype: Fluent """ - self._add_command('drop') + self._add_command("drop") return self @@ -124,7 +123,7 @@ def drop_if_exists(self): :rtype: Fluent """ - return self._add_command('drop_if_exists') + return self._add_command("drop_if_exists") def drop_column(self, *columns): """ @@ -137,7 +136,7 @@ def drop_column(self, *columns): """ columns = list(columns) - return self._add_command('drop_column', columns=columns) + return self._add_command("drop_column", columns=columns) def rename_column(self, from_, to): """ @@ -150,7 +149,7 @@ def rename_column(self, from_, to): :rtype: Fluent """ - return self._add_command('rename_column', **{'from_': from_, 'to': to}) + return self._add_command("rename_column", **{"from_": from_, "to": to}) def drop_primary(self, index=None): """ @@ -161,7 +160,7 @@ def drop_primary(self, index=None): :rtype: dict """ - return self._drop_index_command('drop_primary', 'primary', index) + return 
self._drop_index_command("drop_primary", "primary", index) def drop_unique(self, index): """ @@ -172,7 +171,7 @@ def drop_unique(self, index): :rtype: Fluent """ - return self._drop_index_command('drop_unique', 'unique', index) + return self._drop_index_command("drop_unique", "unique", index) def drop_index(self, index): """ @@ -183,7 +182,7 @@ def drop_index(self, index): :rtype: Fluent """ - return self._drop_index_command('drop_index', 'index', index) + return self._drop_index_command("drop_index", "index", index) def drop_foreign(self, index): """ @@ -194,7 +193,7 @@ def drop_foreign(self, index): :rtype: dict """ - return self._drop_index_command('drop_foreign', 'foreign', index) + return self._drop_index_command("drop_foreign", "foreign", index) def drop_timestamps(self): """ @@ -202,7 +201,7 @@ def drop_timestamps(self): :rtype: Fluent """ - return self.drop_column('created_at', 'updated_at') + return self.drop_column("created_at", "updated_at") def drop_soft_deletes(self): """ @@ -210,7 +209,7 @@ def drop_soft_deletes(self): :rtype: Fluent """ - return self.drop_column('deleted_at') + return self.drop_column("deleted_at") def rename(self, to): """ @@ -221,7 +220,7 @@ def rename(self, to): :rtype: Fluent """ - return self._add_command('rename', to=to) + return self._add_command("rename", to=to) def primary(self, columns, name=None): """ @@ -235,7 +234,7 @@ def primary(self, columns, name=None): :rtype: Fluent """ - return self._index_command('primary', columns, name) + return self._index_command("primary", columns, name) def unique(self, columns, name=None): """ @@ -249,7 +248,7 @@ def unique(self, columns, name=None): :rtype: Fluent """ - return self._index_command('unique', columns, name) + return self._index_command("unique", columns, name) def index(self, columns, name=None): """ @@ -263,7 +262,7 @@ def index(self, columns, name=None): :rtype: Fluent """ - return self._index_command('index', columns, name) + return self._index_command("index", columns, 
name) def foreign(self, columns, name=None): """ @@ -277,7 +276,7 @@ def foreign(self, columns, name=None): :rtype: Fluent """ - return self._index_command('foreign', columns, name) + return self._index_command("foreign", columns, name) def increments(self, column): """ @@ -310,7 +309,7 @@ def char(self, column, length=255): :rtype: Fluent """ - return self._add_column('char', column, length=length) + return self._add_column("char", column, length=length) def string(self, column, length=255): """ @@ -321,7 +320,7 @@ def string(self, column, length=255): :rtype: Fluent """ - return self._add_column('string', column, length=length) + return self._add_column("string", column, length=length) def text(self, column): """ @@ -332,7 +331,7 @@ def text(self, column): :rtype: Fluent """ - return self._add_column('text', column) + return self._add_column("text", column) def medium_text(self, column): """ @@ -343,7 +342,7 @@ def medium_text(self, column): :rtype: Fluent """ - return self._add_column('medium_text', column) + return self._add_column("medium_text", column) def long_text(self, column): """ @@ -354,7 +353,7 @@ def long_text(self, column): :rtype: Fluent """ - return self._add_column('long_text', column) + return self._add_column("long_text", column) def integer(self, column, auto_increment=False, unsigned=False): """ @@ -369,9 +368,9 @@ def integer(self, column, auto_increment=False, unsigned=False): :rtype: Fluent """ - return self._add_column('integer', column, - auto_increment=auto_increment, - unsigned=unsigned) + return self._add_column( + "integer", column, auto_increment=auto_increment, unsigned=unsigned + ) def big_integer(self, column, auto_increment=False, unsigned=False): """ @@ -386,9 +385,9 @@ def big_integer(self, column, auto_increment=False, unsigned=False): :rtype: Fluent """ - return self._add_column('big_integer', column, - auto_increment=auto_increment, - unsigned=unsigned) + return self._add_column( + "big_integer", column, 
auto_increment=auto_increment, unsigned=unsigned + ) def medium_integer(self, column, auto_increment=False, unsigned=False): """ @@ -403,9 +402,9 @@ def medium_integer(self, column, auto_increment=False, unsigned=False): :rtype: Fluent """ - return self._add_column('medium_integer', column, - auto_increment=auto_increment, - unsigned=unsigned) + return self._add_column( + "medium_integer", column, auto_increment=auto_increment, unsigned=unsigned + ) def tiny_integer(self, column, auto_increment=False, unsigned=False): """ @@ -420,9 +419,9 @@ def tiny_integer(self, column, auto_increment=False, unsigned=False): :rtype: Fluent """ - return self._add_column('tiny_integer', column, - auto_increment=auto_increment, - unsigned=unsigned) + return self._add_column( + "tiny_integer", column, auto_increment=auto_increment, unsigned=unsigned + ) def small_integer(self, column, auto_increment=False, unsigned=False): """ @@ -437,13 +436,13 @@ def small_integer(self, column, auto_increment=False, unsigned=False): :rtype: Fluent """ - return self._add_column('small_integer', column, - auto_increment=auto_increment, - unsigned=unsigned) + return self._add_column( + "small_integer", column, auto_increment=auto_increment, unsigned=unsigned + ) def unsigned_integer(self, column, auto_increment=False): """ - Create a new unisgned integer column on the table. + Create a new unsigned integer column on the table. 
:param column: The column :type column: str @@ -480,7 +479,7 @@ def float(self, column, total=8, places=2): :rtype: Fluent """ - return self._add_column('float', column, total=total, places=places) + return self._add_column("float", column, total=total, places=places) def double(self, column, total=None, places=None): """ @@ -495,7 +494,7 @@ def double(self, column, total=None, places=None): :rtype: Fluent """ - return self._add_column('double', column, total=total, places=places) + return self._add_column("double", column, total=total, places=places) def decimal(self, column, total=8, places=2): """ @@ -510,7 +509,7 @@ def decimal(self, column, total=8, places=2): :rtype: Fluent """ - return self._add_column('decimal', column, total=total, places=places) + return self._add_column("decimal", column, total=total, places=places) def boolean(self, column): """ @@ -521,8 +520,8 @@ def boolean(self, column): :rtype: Fluent """ - return self._add_column('boolean', column) - + return self._add_column("boolean", column) + def enum(self, column, allowed): """ Create a new enum column on the table. 
@@ -534,7 +533,7 @@ def enum(self, column, allowed): :rtype: Fluent """ - return self._add_column('enum', column, allowed=allowed) + return self._add_column("enum", column, allowed=allowed) def json(self, column): """ @@ -545,7 +544,7 @@ def json(self, column): :rtype: Fluent """ - return self._add_column('json', column) + return self._add_column("json", column) def date(self, column): """ @@ -556,7 +555,7 @@ def date(self, column): :rtype: Fluent """ - return self._add_column('date', column) + return self._add_column("date", column) def datetime(self, column): """ @@ -567,7 +566,7 @@ def datetime(self, column): :rtype: Fluent """ - return self._add_column('datetime', column) + return self._add_column("datetime", column) def time(self, column): """ @@ -578,7 +577,7 @@ def time(self, column): :rtype: Fluent """ - return self._add_column('time', column) + return self._add_column("time", column) def timestamp(self, column): """ @@ -589,7 +588,7 @@ def timestamp(self, column): :rtype: Fluent """ - return self._add_column('timestamp', column) + return self._add_column("timestamp", column) def nullable_timestamps(self): """ @@ -597,8 +596,8 @@ def nullable_timestamps(self): :rtype: Fluent """ - self.timestamp('created_at').nullable() - self.timestamp('updated_at').nullable() + self.timestamp("created_at").nullable() + self.timestamp("updated_at").nullable() def timestamps(self, use_current=True): """ @@ -607,11 +606,11 @@ def timestamps(self, use_current=True): :rtype: Fluent """ if use_current: - self.timestamp('created_at').use_current() - self.timestamp('updated_at').use_current() + self.timestamp("created_at").use_current() + self.timestamp("updated_at").use_current() else: - self.timestamp('created_at') - self.timestamp('updated_at') + self.timestamp("created_at") + self.timestamp("updated_at") def soft_deletes(self): """ @@ -619,7 +618,7 @@ def soft_deletes(self): :rtype: Fluent """ - return self.timestamp('deleted_at').nullable() + return 
self.timestamp("deleted_at").nullable() def binary(self, column): """ @@ -630,7 +629,7 @@ def binary(self, column): :rtype: Fluent """ - return self._add_column('binary', column) + return self._add_column("binary", column) def morphs(self, name, index_name=None): """ @@ -640,9 +639,9 @@ def morphs(self, name, index_name=None): :type index_name: str """ - self.unsigned_integer('%s_id' % name) - self.string('%s_type' % name) - self.index(['%s_id' % name, '%s_type' % name], index_name) + self.unsigned_integer("%s_id" % name) + self.string("%s_type" % name) + self.index(["%s_id" % name, "%s_type" % name], index_name) def _drop_index_command(self, command, type, index): """ @@ -695,9 +694,13 @@ def _create_index_name(self, type, columns): if not isinstance(columns, list): columns = [columns] - index = '%s_%s_%s' % (self._table, '_'.join([str(column) for column in columns]), type) + index = "%s_%s_%s" % ( + self._table, + "_".join([str(column) for column in columns]), + type, + ) - return index.lower().replace('-', '_').replace('.', '_') + return index.lower().replace("-", "_").replace(".", "_") def _add_column(self, type, name, **parameters): """ @@ -714,10 +717,7 @@ def _add_column(self, type, name, **parameters): :rtype: Fluent """ - parameters.update({ - 'type': type, - 'name': name - }) + parameters.update({"type": type, "name": name}) column = Fluent(**parameters) self._columns.append(column) @@ -766,7 +766,7 @@ def _create_command(self, name, **parameters): :rtype: Fluent """ - parameters.update({'name': name}) + parameters.update({"name": name}) return Fluent(**parameters) @@ -780,7 +780,7 @@ def get_commands(self): return self._commands def get_added_columns(self): - return list(filter(lambda column: not column.get('change'), self._columns)) + return list(filter(lambda column: not column.get("change"), self._columns)) def get_changed_columns(self): - return list(filter(lambda column: column.get('change'), self._columns)) + return list(filter(lambda column: 
column.get("change"), self._columns)) diff --git a/orator/schema/builder.py b/orator/schema/builder.py index 4935b3ef..9e4fb466 100644 --- a/orator/schema/builder.py +++ b/orator/schema/builder.py @@ -5,7 +5,6 @@ class SchemaBuilder(object): - def __init__(self, connection): """ :param connection: The schema connection diff --git a/orator/schema/grammars/grammar.py b/orator/schema/grammars/grammar.py index d9d13ca5..71c759fd 100644 --- a/orator/schema/grammars/grammar.py +++ b/orator/schema/grammars/grammar.py @@ -10,7 +10,6 @@ class SchemaGrammar(Grammar): - def __init__(self, connection): super(SchemaGrammar, self).__init__(marker=connection.get_marker()) @@ -93,19 +92,21 @@ def compile_foreign(self, blueprint, command, _): columns = self.columnize(command.columns) - on_columns = self.columnize(command.references - if isinstance(command.references, list) - else [command.references]) + on_columns = self.columnize( + command.references + if isinstance(command.references, list) + else [command.references] + ) - sql = 'ALTER TABLE %s ADD CONSTRAINT %s ' % (table, command.index) + sql = "ALTER TABLE %s ADD CONSTRAINT %s " % (table, command.index) - sql += 'FOREIGN KEY (%s) REFERENCES %s (%s)' % (columns, on, on_columns) + sql += "FOREIGN KEY (%s) REFERENCES %s (%s)" % (columns, on, on_columns) - if command.get('on_delete'): - sql += ' ON DELETE %s' % command.on_delete + if command.get("on_delete"): + sql += " ON DELETE %s" % command.on_delete - if command.get('on_update'): - sql += ' ON UPDATE %s' % command.on_update + if command.get("on_update"): + sql += " ON UPDATE %s" % command.on_update return sql @@ -121,7 +122,7 @@ def _get_columns(self, blueprint): columns = [] for column in blueprint.get_added_columns(): - sql = self.wrap(column) + ' ' + self._get_type(column) + sql = self.wrap(column) + " " + self._get_type(column) columns.append(self._add_modifiers(sql, blueprint, column)) @@ -132,7 +133,7 @@ def _add_modifiers(self, sql, blueprint, column): Add the column 
modifiers to the deifinition """ for modifier in self._modifiers: - method = '_modify_%s' % modifier + method = "_modify_%s" % modifier if hasattr(self, method): sql += getattr(self, method)(blueprint, column) @@ -163,13 +164,13 @@ def _get_type(self, column): :rtype sql """ - return getattr(self, '_type_%s' % column.type)(column) + return getattr(self, "_type_%s" % column.type)(column) def prefix_list(self, prefix, values): """ Add a prefix to a list of values. """ - return list(map(lambda value: prefix + ' ' + value, values)) + return list(map(lambda value: prefix + " " + value, values)) def wrap_table(self, table): if isinstance(table, Blueprint): @@ -245,9 +246,13 @@ def _get_changed_diff(self, blueprint, schema): :rtype: orator.dbal.TableDiff """ - table = schema.list_table_details(self.get_table_prefix() + blueprint.get_table()) + table = schema.list_table_details( + self.get_table_prefix() + blueprint.get_table() + ) - return Comparator().diff_table(table, self._get_table_with_column_changes(blueprint, table)) + return Comparator().diff_table( + table, self._get_table_with_column_changes(blueprint, table) + ) def _get_table_with_column_changes(self, blueprint, table): """ @@ -269,7 +274,7 @@ def _get_table_with_column_changes(self, blueprint, table): option = self._map_fluent_option(key) if option is not None: - method = 'set_%s' % option + method = "set_%s" % option if hasattr(column, method): getattr(column, method)(self._map_fluent_value(option, value)) @@ -293,13 +298,13 @@ def _get_column_change_options(self, fluent): Get the column change options. 
""" options = { - 'name': fluent.name, - 'type': self._get_dbal_column_type(fluent.type), - 'default': fluent.get('default') + "name": fluent.name, + "type": self._get_dbal_column_type(fluent.type), + "default": fluent.get("default"), } - if fluent.type in ['string']: - options['length'] = fluent.length + if fluent.type in ["string"]: + options["length"] = fluent.length return options @@ -314,29 +319,29 @@ def _get_dbal_column_type(self, type_): """ type_ = type_.lower() - if type_ == 'big_integer': - type_ = 'bigint' - elif type == 'small_integer': - type_ = 'smallint' - elif type_ in ['medium_text', 'long_text']: - type_ = 'text' + if type_ == "big_integer": + type_ = "bigint" + elif type == "small_integer": + type_ = "smallint" + elif type_ in ["medium_text", "long_text"]: + type_ = "text" return type_ def _map_fluent_option(self, attribute): - if attribute in ['type', 'name']: + if attribute in ["type", "name"]: return - elif attribute == 'nullable': - return 'notnull' - elif attribute == 'total': - return 'precision' - elif attribute == 'places': - return 'scale' + elif attribute == "nullable": + return "notnull" + elif attribute == "total": + return "precision" + elif attribute == "places": + return "scale" else: return def _map_fluent_value(self, option, value): - if option == 'notnull': + if option == "notnull": return not value return value diff --git a/orator/schema/grammars/mysql_grammar.py b/orator/schema/grammars/mysql_grammar.py index da21286a..af00ee38 100644 --- a/orator/schema/grammars/mysql_grammar.py +++ b/orator/schema/grammars/mysql_grammar.py @@ -9,14 +9,25 @@ class MySQLSchemaGrammar(SchemaGrammar): _modifiers = [ - 'unsigned', 'charset', 'collate', 'nullable', - 'default', 'increment', 'comment', 'after' + "unsigned", + "charset", + "collate", + "nullable", + "default", + "increment", + "comment", + "after", ] - _serials = ['big_integer', 'integer', - 'medium_integer', 'small_integer', 'tiny_integer'] + _serials = [ + "big_integer", + 
"integer", + "medium_integer", + "small_integer", + "tiny_integer", + ] - marker = '%s' + marker = "%s" def compile_table_exists(self): """ @@ -24,32 +35,36 @@ def compile_table_exists(self): :rtype: str """ - return 'SELECT * ' \ - 'FROM information_schema.tables ' \ - 'WHERE table_schema = %(marker)s ' \ - 'AND table_name = %(marker)s' % {'marker': self.get_marker()} + return ( + "SELECT * " + "FROM information_schema.tables " + "WHERE table_schema = %(marker)s " + "AND table_name = %(marker)s" % {"marker": self.get_marker()} + ) def compile_column_exists(self): """ Compile the query to determine the list of columns. """ - return 'SELECT column_name ' \ - 'FROM information_schema.columns ' \ - 'WHERE table_schema = %(marker)s AND table_name = %(marker)s' \ - % {'marker': self.get_marker()} + return ( + "SELECT column_name " + "FROM information_schema.columns " + "WHERE table_schema = %(marker)s AND table_name = %(marker)s" + % {"marker": self.get_marker()} + ) def compile_create(self, blueprint, command, connection): """ Compile a create table command. 
""" - columns = ', '.join(self._get_columns(blueprint)) + columns = ", ".join(self._get_columns(blueprint)) - sql = 'CREATE TABLE %s (%s)' % (self.wrap_table(blueprint), columns) + sql = "CREATE TABLE %s (%s)" % (self.wrap_table(blueprint), columns) sql = self._compile_create_encoding(sql, connection, blueprint) if blueprint.engine: - sql += ' ENGINE = %s' % blueprint.engine + sql += " ENGINE = %s" % blueprint.engine return sql @@ -63,77 +78,76 @@ def _compile_create_encoding(self, sql, connection, blueprint): :rtype: str """ - charset = blueprint.charset or connection.get_config('charset') + charset = blueprint.charset or connection.get_config("charset") if charset: - sql += ' DEFAULT CHARACTER SET %s' % charset + sql += " DEFAULT CHARACTER SET %s" % charset - collation = blueprint.collation or connection.get_config('collation') + collation = blueprint.collation or connection.get_config("collation") if collation: - sql += ' COLLATE %s' % collation + sql += " COLLATE %s" % collation return sql def compile_add(self, blueprint, command, _): table = self.wrap_table(blueprint) - columns = self.prefix_list('ADD', self._get_columns(blueprint)) + columns = self.prefix_list("ADD", self._get_columns(blueprint)) - return 'ALTER TABLE %s %s' % (table, ', '.join(columns)) + return "ALTER TABLE %s %s" % (table, ", ".join(columns)) def compile_primary(self, blueprint, command, _): command.name = None - return self._compile_key(blueprint, command, 'PRIMARY KEY') + return self._compile_key(blueprint, command, "PRIMARY KEY") def compile_unique(self, blueprint, command, _): - return self._compile_key(blueprint, command, 'UNIQUE') + return self._compile_key(blueprint, command, "UNIQUE") def compile_index(self, blueprint, command, _): - return self._compile_key(blueprint, command, 'INDEX') + return self._compile_key(blueprint, command, "INDEX") def _compile_key(self, blueprint, command, type): columns = self.columnize(command.columns) table = self.wrap_table(blueprint) - return 'ALTER 
TABLE %s ADD %s %s(%s)' % (table, type, command.index, columns) + return "ALTER TABLE %s ADD %s %s(%s)" % (table, type, command.index, columns) def compile_drop(self, blueprint, command, _): - return 'DROP TABLE %s' % self.wrap_table(blueprint) + return "DROP TABLE %s" % self.wrap_table(blueprint) def compile_drop_if_exists(self, blueprint, command, _): - return 'DROP TABLE IF EXISTS %s' % self.wrap_table(blueprint) + return "DROP TABLE IF EXISTS %s" % self.wrap_table(blueprint) def compile_drop_column(self, blueprint, command, connection): - columns = self.prefix_list('DROP', self.wrap_list(command.columns)) + columns = self.prefix_list("DROP", self.wrap_list(command.columns)) table = self.wrap_table(blueprint) - return 'ALTER TABLE %s %s' % (table, ', '.join(columns)) + return "ALTER TABLE %s %s" % (table, ", ".join(columns)) def compile_drop_primary(self, blueprint, command, _): - return 'ALTER TABLE %s DROP PRIMARY KEY'\ - % self.wrap_table(blueprint) + return "ALTER TABLE %s DROP PRIMARY KEY" % self.wrap_table(blueprint) def compile_drop_unique(self, blueprint, command, _): table = self.wrap_table(blueprint) - return 'ALTER TABLE %s DROP INDEX %s' % (table, command.index) + return "ALTER TABLE %s DROP INDEX %s" % (table, command.index) def compile_drop_index(self, blueprint, command, _): table = self.wrap_table(blueprint) - return 'ALTER TABLE %s DROP INDEX %s' % (table, command.index) + return "ALTER TABLE %s DROP INDEX %s" % (table, command.index) def compile_drop_foreign(self, blueprint, command, _): table = self.wrap_table(blueprint) - return 'ALTER TABLE %s DROP FOREIGN KEY %s' % (table, command.index) + return "ALTER TABLE %s DROP FOREIGN KEY %s" % (table, command.index) def compile_rename(self, blueprint, command, _): from_ = self.wrap_table(blueprint) - return 'RENAME TABLE %s TO %s' % (from_, self.wrap_table(command.to)) + return "RENAME TABLE %s TO %s" % (from_, self.wrap_table(command.to)) def _type_char(self, column): return "CHAR(%s)" % 
column.length @@ -142,121 +156,131 @@ def _type_string(self, column): return "VARCHAR(%s)" % column.length def _type_text(self, column): - return 'TEXT' + return "TEXT" def _type_medium_text(self, column): - return 'MEDIUMTEXT' + return "MEDIUMTEXT" def _type_long_text(self, column): - return 'LONGTEXT' + return "LONGTEXT" def _type_integer(self, column): - return 'INT' + return "INT" def _type_big_integer(self, column): - return 'BIGINT' + return "BIGINT" def _type_medium_integer(self, column): - return 'MEDIUMINT' + return "MEDIUMINT" def _type_tiny_integer(self, column): - return 'TINYINT' + return "TINYINT" def _type_small_integer(self, column): - return 'SMALLINT' + return "SMALLINT" def _type_float(self, column): return self._type_double(column) def _type_double(self, column): if column.total and column.places: - return 'DOUBLE(%s, %s)' % (column.total, column.places) + return "DOUBLE(%s, %s)" % (column.total, column.places) - return 'DOUBLE' + return "DOUBLE" def _type_decimal(self, column): - return 'DECIMAL(%s, %s)' % (column.total, column.places) + return "DECIMAL(%s, %s)" % (column.total, column.places) def _type_boolean(self, column): - return 'TINYINT(1)' + return "TINYINT(1)" def _type_enum(self, column): - return 'ENUM(\'%s\')' % '\', \''.join(column.allowed) + return "ENUM('%s')" % "', '".join(column.allowed) def _type_json(self, column): if self.platform().has_native_json_type(): - return 'JSON' + return "JSON" - return 'TEXT' + return "TEXT" def _type_date(self, column): - return 'DATE' + return "DATE" def _type_datetime(self, column): - return 'DATETIME' + return "DATETIME" def _type_time(self, column): - return 'TIME' + return "TIME" def _type_timestamp(self, column): - if column.use_current: - if self.platform_version() >= (5, 6): - return 'TIMESTAMP DEFAULT CURRENT_TIMESTAMP' + platform_version = self.platform_version(3) + column_type = "TIMESTAMP" + + if platform_version >= (5, 6, 0): + if platform_version >= (5, 6, 4): + # Versions 5.6.4+ 
support fractional seconds + column_type = "TIMESTAMP(6)" + current = "CURRENT_TIMESTAMP(6)" else: - return 'TIMESTAMP DEFAULT 0' + current = "CURRENT_TIMESTAMP" + else: + current = "0" + + if column.use_current: + return "{} DEFAULT {}".format(column_type, current) - return 'TIMESTAMP' + return column_type def _type_binary(self, column): - return 'BLOB' + return "BLOB" def _modify_unsigned(self, blueprint, column): - if column.get('unsigned', False): - return ' UNSIGNED' + if column.get("unsigned", False): + return " UNSIGNED" - return '' + return "" def _modify_charset(self, blueprint, column): - if column.get('charset'): - return ' CHARACTER SET ' + column.charset + if column.get("charset"): + return " CHARACTER SET " + column.charset - return '' + return "" def _modify_collate(self, blueprint, column): - if column.get('collation'): - return ' COLLATE ' + column.collation + if column.get("collation"): + return " COLLATE " + column.collation - return '' + return "" def _modify_nullable(self, blueprint, column): - if column.get('nullable'): - return ' NULL' + if column.get("nullable"): + return " NULL" - return ' NOT NULL' + return " NOT NULL" def _modify_default(self, blueprint, column): - if column.get('default') is not None: - return ' DEFAULT %s' % self._get_default_value(column.default) + if column.get("default") is not None: + return " DEFAULT %s" % self._get_default_value(column.default) - return '' + return "" def _modify_increment(self, blueprint, column): if column.type in self._serials and column.auto_increment: - return ' AUTO_INCREMENT PRIMARY KEY' + return " AUTO_INCREMENT PRIMARY KEY" - return '' + return "" def _modify_after(self, blueprint, column): - if column.get('after') is not None: - return ' AFTER ' + self.wrap(column.after) + if column.get("after") is not None: + return " AFTER " + self.wrap(column.after) - return '' + return "" def _modify_comment(self, blueprint, column): - if column.get('comment') is not None: + if column.get("comment") 
is not None: return ' COMMENT "%s"' % column.comment - return '' + return "" def _get_column_change_options(self, fluent): """ @@ -264,15 +288,15 @@ def _get_column_change_options(self, fluent): """ options = super(MySQLSchemaGrammar, self)._get_column_change_options(fluent) - if fluent.type == 'enum': - options['extra'] = { - 'definition': '(\'{}\')'.format('\',\''.join(fluent.allowed)) + if fluent.type == "enum": + options["extra"] = { + "definition": "('{}')".format("','".join(fluent.allowed)) } return options def _wrap_value(self, value): - if value == '*': + if value == "*": return value - return '`%s`' % value.replace('`', '``') + return "`%s`" % value.replace("`", "``") diff --git a/orator/schema/grammars/postgres_grammar.py b/orator/schema/grammars/postgres_grammar.py index 1495eb1c..e1086f8b 100644 --- a/orator/schema/grammars/postgres_grammar.py +++ b/orator/schema/grammars/postgres_grammar.py @@ -8,12 +8,17 @@ class PostgresSchemaGrammar(SchemaGrammar): - _modifiers = ['increment', 'nullable', 'default'] + _modifiers = ["increment", "nullable", "default"] - _serials = ['big_integer', 'integer', - 'medium_integer', 'small_integer', 'tiny_integer'] + _serials = [ + "big_integer", + "integer", + "medium_integer", + "small_integer", + "tiny_integer", + ] - marker = '%s' + marker = "%s" def compile_rename_column(self, blueprint, command, connection): """ @@ -34,8 +39,11 @@ def compile_rename_column(self, blueprint, command, connection): column = self.wrap(command.from_) - return 'ALTER TABLE %s RENAME COLUMN %s TO %s'\ - % (table, column, self.wrap(command.to)) + return "ALTER TABLE %s RENAME COLUMN %s TO %s" % ( + table, + column, + self.wrap(command.to), + ) def compile_table_exists(self): """ @@ -43,91 +51,101 @@ def compile_table_exists(self): :rtype: str """ - return 'SELECT * ' \ - 'FROM information_schema.tables ' \ - 'WHERE table_name = \'%(marker)s\'' \ - % {'marker': self.get_marker()} + return ( + "SELECT * " + "FROM information_schema.tables " + 
"WHERE table_name = %(marker)s" % {"marker": self.get_marker()} + ) def compile_column_exists(self, table): """ Compile the query to determine the list of columns. """ - return 'SELECT column_name ' \ - 'FROM information_schema.columns ' \ - 'WHERE table_name = \'%s\'' % table + return ( + "SELECT column_name " + "FROM information_schema.columns " + "WHERE table_name = '%s'" % table + ) def compile_create(self, blueprint, command, _): """ Compile a create table command. """ - columns = ', '.join(self._get_columns(blueprint)) + columns = ", ".join(self._get_columns(blueprint)) - return 'CREATE TABLE %s (%s)' % (self.wrap_table(blueprint), columns) + return "CREATE TABLE %s (%s)" % (self.wrap_table(blueprint), columns) def compile_add(self, blueprint, command, _): table = self.wrap_table(blueprint) - columns = self.prefix_list('ADD COLUMN', self._get_columns(blueprint)) + columns = self.prefix_list("ADD COLUMN", self._get_columns(blueprint)) - return 'ALTER TABLE %s %s' % (table, ', '.join(columns)) + return "ALTER TABLE %s %s" % (table, ", ".join(columns)) def compile_primary(self, blueprint, command, _): columns = self.columnize(command.columns) - return 'ALTER TABLE %s ADD PRIMARY KEY (%s)'\ - % (self.wrap_table(blueprint), columns) + return "ALTER TABLE %s ADD PRIMARY KEY (%s)" % ( + self.wrap_table(blueprint), + columns, + ) def compile_unique(self, blueprint, command, _): columns = self.columnize(command.columns) table = self.wrap_table(blueprint) - return 'ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)'\ - % (table, command.index, columns) + return "ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)" % ( + table, + command.index, + columns, + ) def compile_index(self, blueprint, command, _): columns = self.columnize(command.columns) table = self.wrap_table(blueprint) - return 'CREATE INDEX %s ON %s (%s)' % (command.index, table, columns) + return "CREATE INDEX %s ON %s (%s)" % (command.index, table, columns) def compile_drop(self, blueprint, command, _): - return 
'DROP TABLE %s' % self.wrap_table(blueprint) + return "DROP TABLE %s" % self.wrap_table(blueprint) def compile_drop_if_exists(self, blueprint, command, _): - return 'DROP TABLE IF EXISTS %s' % self.wrap_table(blueprint) + return "DROP TABLE IF EXISTS %s" % self.wrap_table(blueprint) def compile_drop_column(self, blueprint, command, connection): - columns = self.prefix_list('DROP COLUMN', self.wrap_list(command.columns)) + columns = self.prefix_list("DROP COLUMN", self.wrap_list(command.columns)) table = self.wrap_table(blueprint) - return 'ALTER TABLE %s %s' % (table, ', '.join(columns)) + return "ALTER TABLE %s %s" % (table, ", ".join(columns)) def compile_drop_primary(self, blueprint, command, _): table = blueprint.get_table() - return 'ALTER TABLE %s DROP CONSTRAINT %s_pkey'\ - % (self.wrap_table(blueprint), table) + return "ALTER TABLE %s DROP CONSTRAINT %s_pkey" % ( + self.wrap_table(blueprint), + table, + ) def compile_drop_unique(self, blueprint, command, _): table = self.wrap_table(blueprint) - return 'ALTER TABLE %s DROP CONSTRAINT %s' % (table, command.index) + return "ALTER TABLE %s DROP CONSTRAINT %s" % (table, command.index) def compile_drop_index(self, blueprint, command, _): - return 'DROP INDEX %s' % command.index + return "DROP INDEX %s" % command.index def compile_drop_foreign(self, blueprint, command, _): table = self.wrap_table(blueprint) - return 'ALTER TABLE %s DROP CONSTRAINT %s' % (table, command.index) + return "ALTER TABLE %s DROP CONSTRAINT %s" % (table, command.index) def compile_rename(self, blueprint, command, _): from_ = self.wrap_table(blueprint) - return 'ALTER TABLE %s RENAME TO %s' % (from_, self.wrap_table(command.to)) + return "ALTER TABLE %s RENAME TO %s" % (from_, self.wrap_table(command.to)) def _type_char(self, column): return "CHAR(%s)" % column.length @@ -136,84 +154,84 @@ def _type_string(self, column): return "VARCHAR(%s)" % column.length def _type_text(self, column): - return 'TEXT' + return "TEXT" def 
_type_medium_text(self, column): - return 'TEXT' + return "TEXT" def _type_long_text(self, column): - return 'TEXT' + return "TEXT" def _type_integer(self, column): - return 'SERIAL' if column.auto_increment else 'INTEGER' + return "SERIAL" if column.auto_increment else "INTEGER" def _type_big_integer(self, column): - return 'BIGSERIAL' if column.auto_increment else 'BIGINT' + return "BIGSERIAL" if column.auto_increment else "BIGINT" def _type_medium_integer(self, column): - return 'SERIAL' if column.auto_increment else 'INTEGER' + return "SERIAL" if column.auto_increment else "INTEGER" def _type_tiny_integer(self, column): - return 'SMALLSERIAL' if column.auto_increment else 'SMALLINT' + return "SMALLSERIAL" if column.auto_increment else "SMALLINT" def _type_small_integer(self, column): - return 'SMALLSERIAL' if column.auto_increment else 'SMALLINT' + return "SMALLSERIAL" if column.auto_increment else "SMALLINT" def _type_float(self, column): return self._type_double(column) def _type_double(self, column): - return 'DOUBLE PRECISION' + return "DOUBLE PRECISION" def _type_decimal(self, column): - return 'DECIMAL(%s, %s)' % (column.total, column.places) + return "DECIMAL(%s, %s)" % (column.total, column.places) def _type_boolean(self, column): - return 'BOOLEAN' + return "BOOLEAN" def _type_enum(self, column): allowed = list(map(lambda a: "'%s'" % a, column.allowed)) - return 'VARCHAR(255) CHECK ("%s" IN (%s))' % (column.name, ', '.join(allowed)) + return 'VARCHAR(255) CHECK ("%s" IN (%s))' % (column.name, ", ".join(allowed)) def _type_json(self, column): - return 'JSON' + return "JSON" def _type_date(self, column): - return 'DATE' + return "DATE" def _type_datetime(self, column): - return 'TIMESTAMP(0) WITHOUT TIME ZONE' + return "TIMESTAMP(6) WITHOUT TIME ZONE" def _type_time(self, column): - return 'TIME(0) WITHOUT TIME ZONE' + return "TIME(6) WITHOUT TIME ZONE" def _type_timestamp(self, column): if column.use_current: - return 'TIMESTAMP(0) WITHOUT TIME ZONE 
DEFAULT CURRENT_TIMESTAMP(0)' + return "TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(6)" - return 'TIMESTAMP(0) WITHOUT TIME ZONE' + return "TIMESTAMP(6) WITHOUT TIME ZONE" def _type_binary(self, column): - return 'BYTEA' + return "BYTEA" def _modify_nullable(self, blueprint, column): - if column.get('nullable'): - return ' NULL' + if column.get("nullable"): + return " NULL" - return ' NOT NULL' + return " NOT NULL" def _modify_default(self, blueprint, column): - if column.get('default') is not None: - return ' DEFAULT %s' % self._get_default_value(column.default) + if column.get("default") is not None: + return " DEFAULT %s" % self._get_default_value(column.default) - return '' + return "" def _modify_increment(self, blueprint, column): if column.type in self._serials and column.auto_increment: - return ' PRIMARY KEY' + return " PRIMARY KEY" - return '' + return "" def _get_dbal_column_type(self, type_): """ @@ -226,7 +244,7 @@ def _get_dbal_column_type(self, type_): """ type_ = type_.lower() - if type_ == 'enum': - return 'string' + if type_ == "enum": + return "string" return super(PostgresSchemaGrammar, self)._get_dbal_column_type(type_) diff --git a/orator/schema/grammars/sqlite_grammar.py b/orator/schema/grammars/sqlite_grammar.py index 8aeb4971..98780c31 100644 --- a/orator/schema/grammars/sqlite_grammar.py +++ b/orator/schema/grammars/sqlite_grammar.py @@ -8,9 +8,9 @@ class SQLiteSchemaGrammar(SchemaGrammar): - _modifiers = ['nullable', 'default', 'increment'] + _modifiers = ["nullable", "default", "increment"] - _serials = ['big_integer', 'integer'] + _serials = ["big_integer", "integer"] def compile_rename_column(self, blueprint, command, connection): """ @@ -29,16 +29,18 @@ def compile_rename_column(self, blueprint, command, connection): """ sql = [] # If foreign keys are on, we disable them - foreign_keys = self._connection.select('PRAGMA foreign_keys') + foreign_keys = self._connection.select("PRAGMA foreign_keys") if foreign_keys: 
foreign_keys = bool(foreign_keys[0]) if foreign_keys: - sql.append('PRAGMA foreign_keys = OFF') + sql.append("PRAGMA foreign_keys = OFF") - sql += super(SQLiteSchemaGrammar, self).compile_rename_column(blueprint, command, connection) + sql += super(SQLiteSchemaGrammar, self).compile_rename_column( + blueprint, command, connection + ) if foreign_keys: - sql.append('PRAGMA foreign_keys = ON') + sql.append("PRAGMA foreign_keys = ON") return sql @@ -59,16 +61,18 @@ def compile_change(self, blueprint, command, connection): """ sql = [] # If foreign keys are on, we disable them - foreign_keys = self._connection.select('PRAGMA foreign_keys') + foreign_keys = self._connection.select("PRAGMA foreign_keys") if foreign_keys: foreign_keys = bool(foreign_keys[0]) if foreign_keys: - sql.append('PRAGMA foreign_keys = OFF') + sql.append("PRAGMA foreign_keys = OFF") - sql += super(SQLiteSchemaGrammar, self).compile_change(blueprint, command, connection) + sql += super(SQLiteSchemaGrammar, self).compile_change( + blueprint, command, connection + ) if foreign_keys: - sql.append('PRAGMA foreign_keys = ON') + sql.append("PRAGMA foreign_keys = ON") return sql @@ -78,41 +82,44 @@ def compile_table_exists(self): :rtype: str """ - return "SELECT * FROM sqlite_master WHERE type = 'table' AND name = %(marker)s" % {'marker': self.get_marker()} + return ( + "SELECT * FROM sqlite_master WHERE type = 'table' AND name = %(marker)s" + % {"marker": self.get_marker()} + ) def compile_column_exists(self, table): """ Compile the query to determine the list of columns. """ - return 'PRAGMA table_info(%s)' % table.replace('.', '__') + return "PRAGMA table_info(%s)" % table.replace(".", "__") def compile_create(self, blueprint, command, _): """ Compile a create table command. 
""" - columns = ', '.join(self._get_columns(blueprint)) + columns = ", ".join(self._get_columns(blueprint)) - sql = 'CREATE TABLE %s (%s' % (self.wrap_table(blueprint), columns) + sql = "CREATE TABLE %s (%s" % (self.wrap_table(blueprint), columns) sql += self._add_foreign_keys(blueprint) sql += self._add_primary_keys(blueprint) - return sql + ')' + return sql + ")" def _add_foreign_keys(self, blueprint): - sql = '' + sql = "" - foreigns = self._get_commands_by_name(blueprint, 'foreign') + foreigns = self._get_commands_by_name(blueprint, "foreign") for foreign in foreigns: sql += self._get_foreign_key(foreign) - if foreign.get('on_delete'): - sql += ' ON DELETE %s' % foreign.on_delete + if foreign.get("on_delete"): + sql += " ON DELETE %s" % foreign.on_delete - if foreign.get('on_update'): - sql += ' ON UPDATE %s' % foreign.on_delete + if foreign.get("on_update"): + sql += " ON UPDATE %s" % foreign.on_delete return sql @@ -127,27 +134,27 @@ def _get_foreign_key(self, foreign): on_columns = self.columnize(references) - return ', FOREIGN KEY(%s) REFERENCES %s(%s)' % (columns, on, on_columns) + return ", FOREIGN KEY(%s) REFERENCES %s(%s)" % (columns, on, on_columns) def _add_primary_keys(self, blueprint): - primary = self._get_command_by_name(blueprint, 'primary') + primary = self._get_command_by_name(blueprint, "primary") if primary: columns = self.columnize(primary.columns) - return ', PRIMARY KEY (%s)' % columns + return ", PRIMARY KEY (%s)" % columns - return '' + return "" def compile_add(self, blueprint, command, _): table = self.wrap_table(blueprint) - columns = self.prefix_list('ADD COLUMN', self._get_columns(blueprint)) + columns = self.prefix_list("ADD COLUMN", self._get_columns(blueprint)) statements = [] for column in columns: - statements.append('ALTER TABLE %s %s' % (table, column)) + statements.append("ALTER TABLE %s %s" % (table, column)) return statements @@ -156,23 +163,23 @@ def compile_unique(self, blueprint, command, _): table = 
self.wrap_table(blueprint) - return 'CREATE UNIQUE INDEX %s ON %s (%s)' % (command.index, table, columns) + return "CREATE UNIQUE INDEX %s ON %s (%s)" % (command.index, table, columns) def compile_index(self, blueprint, command, _): columns = self.columnize(command.columns) table = self.wrap_table(blueprint) - return 'CREATE INDEX %s ON %s (%s)' % (command.index, table, columns) + return "CREATE INDEX %s ON %s (%s)" % (command.index, table, columns) def compile_foreign(self, blueprint, command, _): pass def compile_drop(self, blueprint, command, _): - return 'DROP TABLE %s' % self.wrap_table(blueprint) + return "DROP TABLE %s" % self.wrap_table(blueprint) def compile_drop_if_exists(self, blueprint, command, _): - return 'DROP TABLE IF EXISTS %s' % self.wrap_table(blueprint) + return "DROP TABLE IF EXISTS %s" % self.wrap_table(blueprint) def compile_drop_column(self, blueprint, command, connection): schema = connection.get_schema_manager() @@ -187,99 +194,99 @@ def compile_drop_column(self, blueprint, command, connection): return schema.get_database_platform().get_alter_table_sql(table_diff) def compile_drop_unique(self, blueprint, command, _): - return 'DROP INDEX %s' % command.index + return "DROP INDEX %s" % command.index def compile_drop_index(self, blueprint, command, _): - return 'DROP INDEX %s' % command.index + return "DROP INDEX %s" % command.index def compile_rename(self, blueprint, command, _): from_ = self.wrap_table(blueprint) - return 'ALTER TABLE %s RENAME TO %s' % (from_, self.wrap_table(command.to)) + return "ALTER TABLE %s RENAME TO %s" % (from_, self.wrap_table(command.to)) def _type_char(self, column): - return 'VARCHAR' + return "VARCHAR" def _type_string(self, column): - return 'VARCHAR' + return "VARCHAR" def _type_text(self, column): - return 'TEXT' + return "TEXT" def _type_medium_text(self, column): - return 'TEXT' + return "TEXT" def _type_long_text(self, column): - return 'TEXT' + return "TEXT" def _type_integer(self, column): - return 
'INTEGER' + return "INTEGER" def _type_big_integer(self, column): - return 'INTEGER' + return "INTEGER" def _type_medium_integer(self, column): - return 'INTEGER' + return "INTEGER" def _type_tiny_integer(self, column): - return 'TINYINT' + return "TINYINT" def _type_small_integer(self, column): - return 'INTEGER' + return "INTEGER" def _type_float(self, column): - return 'FLOAT' + return "FLOAT" def _type_double(self, column): - return 'FLOAT' + return "FLOAT" def _type_decimal(self, column): - return 'NUMERIC' + return "NUMERIC" def _type_boolean(self, column): - return 'TINYINT' + return "TINYINT" def _type_enum(self, column): - return 'VARCHAR' + return "VARCHAR" def _type_json(self, column): - return 'TEXT' + return "TEXT" def _type_date(self, column): - return 'DATE' + return "DATE" def _type_datetime(self, column): - return 'DATETIME' + return "DATETIME" def _type_time(self, column): - return 'TIME' + return "TIME" def _type_timestamp(self, column): if column.use_current: - return 'DATETIME DEFAULT CURRENT_TIMESTAMP' + return "DATETIME DEFAULT CURRENT_TIMESTAMP" - return 'DATETIME' + return "DATETIME" def _type_binary(self, column): - return 'BLOB' + return "BLOB" def _modify_nullable(self, blueprint, column): - if column.get('nullable'): - return ' NULL' + if column.get("nullable"): + return " NULL" - return ' NOT NULL' + return " NOT NULL" def _modify_default(self, blueprint, column): - if column.get('default') is not None: - return ' DEFAULT %s' % self._get_default_value(column.default) + if column.get("default") is not None: + return " DEFAULT %s" % self._get_default_value(column.default) - return '' + return "" def _modify_increment(self, blueprint, column): if column.type in self._serials and column.auto_increment: - return ' PRIMARY KEY AUTOINCREMENT' + return " PRIMARY KEY AUTOINCREMENT" - return '' + return "" def _get_dbal_column_type(self, type_): """ @@ -292,7 +299,7 @@ def _get_dbal_column_type(self, type_): """ type_ = type_.lower() - if type_ 
== 'enum': - return 'string' + if type_ == "enum": + return "string" return super(SQLiteSchemaGrammar, self)._get_dbal_column_type(type_) diff --git a/orator/schema/mysql_builder.py b/orator/schema/mysql_builder.py index 7d54d34f..62d95e5b 100644 --- a/orator/schema/mysql_builder.py +++ b/orator/schema/mysql_builder.py @@ -4,7 +4,6 @@ class MySQLSchemaBuilder(SchemaBuilder): - def has_table(self, table): """ Determine if the given table exists. @@ -33,6 +32,12 @@ def get_column_listing(self, table): database = self._connection.get_database_name() table = self._connection.get_table_prefix() + table - results = self._connection.select(sql, [database, table]) + results = [] + for result in self._connection.select(sql, [database, table]): + new_result = {} + for key, value in result.items(): + new_result[key.lower()] = value + + results.append(new_result) return self._connection.get_post_processor().process_column_listing(results) diff --git a/orator/schema/schema.py b/orator/schema/schema.py index 98b4b67f..e95c29ec 100644 --- a/orator/schema/schema.py +++ b/orator/schema/schema.py @@ -2,7 +2,6 @@ class Schema(object): - def __init__(self, manager): """ :param manager: The database manager diff --git a/orator/seeds/seeder.py b/orator/seeds/seeder.py index 978b4c67..911e6f64 100644 --- a/orator/seeds/seeder.py +++ b/orator/seeds/seeder.py @@ -32,7 +32,7 @@ def call(self, klass): self._resolve(klass).run() if self._command: - self._command.line('Seeded: %s' % klass.__name__) + self._command.line("Seeded: %s" % klass.__name__) def _resolve(self, klass): """ diff --git a/orator/support/fluent.py b/orator/support/fluent.py index f01e4ba8..d0dc388a 100644 --- a/orator/support/fluent.py +++ b/orator/support/fluent.py @@ -29,7 +29,6 @@ def __set_value(self, value): class Fluent(object): - def __init__(self, **attributes): self._attributes = {} @@ -78,7 +77,7 @@ def __getattr__(self, item): return Dynamic(self._attributes.get(item), item, self) def __setattr__(self, key, 
value): - if key == '_attributes': + if key == "_attributes": super(Fluent, self).__setattr__(key, value) try: diff --git a/orator/support/grammar.py b/orator/support/grammar.py index 7d9dc9ee..93ac7bd9 100644 --- a/orator/support/grammar.py +++ b/orator/support/grammar.py @@ -5,10 +5,10 @@ class Grammar(object): - marker = '?' + marker = "?" def __init__(self, marker=None): - self._table_prefix = '' + self._table_prefix = "" if marker: self.marker = marker @@ -30,18 +30,17 @@ def wrap(self, value, prefix_alias=False): # to separate out the pieces so we can wrap each of the segments # of the expression on it own, and then joins them # both back together with the "as" connector. - if value.lower().find(' as ') >= 0: - segments = value.split(' ') + if value.lower().find(" as ") >= 0: + segments = value.split(" ") if prefix_alias: segments[2] = self._table_prefix + segments[2] - return '%s AS %s' % (self.wrap(segments[0]), - self._wrap_value(segments[2])) + return "%s AS %s" % (self.wrap(segments[0]), self._wrap_value(segments[2])) wrapped = [] - segments = value.split('.') + segments = value.split(".") # If the value is not an aliased table expression, we'll just wrap it like # normal, so if there is more than one segment, we will wrap the first @@ -52,19 +51,19 @@ def wrap(self, value, prefix_alias=False): else: wrapped.append(self._wrap_value(segment)) - return '.'.join(wrapped) + return ".".join(wrapped) def _wrap_value(self, value): - if value == '*': + if value == "*": return value return '"%s"' % value.replace('"', '""') def columnize(self, columns): - return ', '.join(map(self.wrap, columns)) + return ", ".join(map(self.wrap, columns)) def parameterize(self, values): - return ', '.join(map(self.parameter, values)) + return ", ".join(map(self.parameter, values)) def parameter(self, value): if self.is_expression(value): @@ -79,7 +78,7 @@ def is_expression(self, value): return isinstance(value, QueryExpression) def get_date_format(self): - return 'Y-m-d H:i:s' + 
return "%Y-%m-%d %H:%M:%S.%f" def get_table_prefix(self): return self._table_prefix diff --git a/orator/utils/__init__.py b/orator/utils/__init__.py index ac9d385e..e5a70c3f 100644 --- a/orator/utils/__init__.py +++ b/orator/utils/__init__.py @@ -21,10 +21,12 @@ from urlparse import parse_qsl def load_module(module, path): - with open(path, 'rb') as fh: + with open(path, "rb") as fh: mod = imp.load_source(module, path, fh) return mod + + else: long = int unicode = str @@ -32,21 +34,19 @@ def load_module(module, path): from functools import reduce - from urllib.parse import (quote_plus, unquote_plus, - parse_qsl, quote, unquote) + from urllib.parse import quote_plus, unquote_plus, parse_qsl, quote, unquote if PY33: from importlib import machinery def load_module(module, path): - return machinery.SourceFileLoader( - module, path - ).load_module(module) + return machinery.SourceFileLoader(module, path).load_module(module) + else: import imp def load_module(module, path): - with open(path, 'rb') as fh: + with open(path, "rb") as fh: mod = imp.load_source(module, path, fh) return mod @@ -56,7 +56,6 @@ def load_module(module, path): class Null(object): - def __bool__(self): return False @@ -65,9 +64,9 @@ def __eq__(self, other): def deprecated(func): - '''This is a decorator which can be used to mark functions + """This is a decorator which can be used to mark functions as deprecated. 
It will result in a warning being emitted - when the function is used.''' + when the function is used.""" @functools.wraps(func) def new_func(*args, **kwargs): @@ -80,7 +79,7 @@ def new_func(*args, **kwargs): "Call to deprecated function {}.".format(func.__name__), category=DeprecationWarning, filename=func_code.co_filename, - lineno=func_code.co_firstlineno + 1 + lineno=func_code.co_firstlineno + 1, ) return func(*args, **kwargs) @@ -96,7 +95,7 @@ def decode(string, encodings=None): return string if encodings is None: - encodings = ['utf-8', 'latin1', 'ascii'] + encodings = ["utf-8", "latin1", "ascii"] for encoding in encodings: try: @@ -104,7 +103,7 @@ def decode(string, encodings=None): except (UnicodeEncodeError, UnicodeDecodeError): pass - return string.decode(encodings[0], errors='ignore') + return string.decode(encodings[0], errors="ignore") def encode(string, encodings=None): @@ -115,7 +114,7 @@ def encode(string, encodings=None): return string if encodings is None: - encodings = ['utf-8', 'latin1', 'ascii'] + encodings = ["utf-8", "latin1", "ascii"] for encoding in encodings: try: @@ -123,4 +122,4 @@ def encode(string, encodings=None): except (UnicodeEncodeError, UnicodeDecodeError): pass - return string.encode(encodings[0], errors='ignore') + return string.encode(encodings[0], errors="ignore") diff --git a/orator/utils/command_formatter.py b/orator/utils/command_formatter.py index f2b28b7c..24f7cb77 100644 --- a/orator/utils/command_formatter.py +++ b/orator/utils/command_formatter.py @@ -1,41 +1,48 @@ # -*- coding: utf-8 -*- from pygments.formatter import Formatter -from pygments.token import Keyword, Name, Comment, String, Error, \ - Number, Operator, Generic, Token, Whitespace +from pygments.token import ( + Keyword, + Name, + Comment, + String, + Error, + Number, + Operator, + Generic, + Token, + Whitespace, +) from pygments.util import get_choice_opt COMMAND_COLORS = { - Token: ('', ''), - - Whitespace: ('fg=white', 'fg=black;options=bold'), - 
Comment: ('fg=white', 'fg=black;options=bold'), - Comment.Preproc: ('fg=cyan', 'fg=cyan;options=bold'), - Keyword: ('fg=blue', 'fg=blue;options=bold'), - Keyword.Type: ('fg=cyan', 'fg=cyan;options=bold'), - Operator.Word: ('fg=magenta', 'fg=magenta;options=bold'), - Name.Builtin: ('fg=cyan', 'fg=cyan;options=bold'), - Name.Function: ('fg=green', 'fg=green;option=bold'), - Name.Namespace: ('fg=cyan;options=underline', 'fg=cyan;options=bold,underline'), - Name.Class: ('fg=green;options=underline', 'fg=green;options=bold,underline'), - Name.Exception: ('fg=cyan', 'fg=cyan;options=bold'), - Name.Decorator: ('fg=black;options=bold', 'fg=white'), - Name.Variable: ('fg=red', 'fg=red;options=bold'), - Name.Constant: ('fg=red', 'fg=red;options=bold'), - Name.Attribute: ('fg=cyan', 'fg=cyan;options=bold'), - Name.Tag: ('fg=blue;options=bold', 'fg=blue;options=bold'), - String: ('fg=yellow', 'fg=yellow'), - Number: ('fg=blue', 'fg=blue;options=bold'), - - Generic.Deleted: ('fg=red;options=bold', 'fg=red;options=bold'), - Generic.Inserted: ('fg=green', 'fg=green;options=bold'), - Generic.Heading: ('options=bold', 'option=bold'), - Generic.Subheading: ('fg=magenta;options=bold', 'fg=magenta;options=bold'), - Generic.Prompt: ('options=bold', 'options=bold'), - Generic.Error: ('fg=red;options=bold', 'fg=red;options=bold'), - - Error: ('fg=red;options=bold,underline', 'fg=red;options=bold,underline'), + Token: ("", ""), + Whitespace: ("fg=white", "fg=black;options=bold"), + Comment: ("fg=white", "fg=black;options=bold"), + Comment.Preproc: ("fg=cyan", "fg=cyan;options=bold"), + Keyword: ("fg=blue", "fg=blue;options=bold"), + Keyword.Type: ("fg=cyan", "fg=cyan;options=bold"), + Operator.Word: ("fg=magenta", "fg=magenta;options=bold"), + Name.Builtin: ("fg=cyan", "fg=cyan;options=bold"), + Name.Function: ("fg=green", "fg=green;option=bold"), + Name.Namespace: ("fg=cyan;options=underline", "fg=cyan;options=bold,underline"), + Name.Class: ("fg=green;options=underline", 
"fg=green;options=bold,underline"), + Name.Exception: ("fg=cyan", "fg=cyan;options=bold"), + Name.Decorator: ("fg=black;options=bold", "fg=white"), + Name.Variable: ("fg=red", "fg=red;options=bold"), + Name.Constant: ("fg=red", "fg=red;options=bold"), + Name.Attribute: ("fg=cyan", "fg=cyan;options=bold"), + Name.Tag: ("fg=blue;options=bold", "fg=blue;options=bold"), + String: ("fg=yellow", "fg=yellow"), + Number: ("fg=blue", "fg=blue;options=bold"), + Generic.Deleted: ("fg=red;options=bold", "fg=red;options=bold"), + Generic.Inserted: ("fg=green", "fg=green;options=bold"), + Generic.Heading: ("options=bold", "option=bold"), + Generic.Subheading: ("fg=magenta;options=bold", "fg=magenta;options=bold"), + Generic.Prompt: ("options=bold", "options=bold"), + Generic.Error: ("fg=red;options=bold", "fg=red;options=bold"), + Error: ("fg=red;options=bold,underline", "fg=red;options=bold,underline"), } @@ -62,16 +69,17 @@ class CommandFormatter(Formatter): Set to ``True`` to have line numbers on the terminal output as well (default: ``False`` = no line numbers). 
""" - name = 'Command' - aliases = ['command'] + name = "Command" + aliases = ["command"] filenames = [] def __init__(self, **options): Formatter.__init__(self, **options) - self.darkbg = get_choice_opt(options, 'bg', - ['light', 'dark'], 'light') == 'dark' - self.colorscheme = options.get('colorscheme', None) or COMMAND_COLORS - self.linenos = options.get('linenos', False) + self.darkbg = ( + get_choice_opt(options, "bg", ["light", "dark"], "light") == "dark" + ) + self.colorscheme = options.get("colorscheme", None) or COMMAND_COLORS + self.linenos = options.get("linenos", False) self._lineno = 0 def format(self, tokensource, outfile): @@ -79,7 +87,7 @@ def format(self, tokensource, outfile): def _write_lineno(self, outfile): self._lineno += 1 - outfile.write("%s%04d: " % (self._lineno != 1 and '\n' or '', self._lineno)) + outfile.write("%s%04d: " % (self._lineno != 1 and "\n" or "", self._lineno)) def _get_color(self, ttype): # self.colorscheme is a dict containing usually generic types, so we @@ -100,14 +108,14 @@ def format_unencoded(self, tokensource, outfile): for line in value.splitlines(True): if color: - outfile.write('<%s>%s' % (color, line.rstrip('\n'))) + outfile.write("<%s>%s" % (color, line.rstrip("\n"))) else: - outfile.write(line.rstrip('\n')) - if line.endswith('\n'): + outfile.write(line.rstrip("\n")) + if line.endswith("\n"): if self.linenos: self._write_lineno(outfile) else: - outfile.write('\n') + outfile.write("\n") if self.linenos: outfile.write("\n") diff --git a/orator/utils/helpers.py b/orator/utils/helpers.py index 27b11b69..1b54f44a 100644 --- a/orator/utils/helpers.py +++ b/orator/utils/helpers.py @@ -2,6 +2,7 @@ import os import errno +import datetime def value(val): @@ -19,3 +20,18 @@ def mkdir_p(path, mode=0o777): pass else: raise + + +def serialize(value): + if isinstance(value, datetime.datetime): + if hasattr(value, "to_json"): + value = value.to_json() + else: + value = value.isoformat() + elif isinstance(value, list): + value = 
list(map(serialize, value)) + elif isinstance(value, dict): + for k, v in value.items(): + value[k] = serialize(v) + + return value diff --git a/orator/utils/qmarker.py b/orator/utils/qmarker.py index 5ee784b5..d6b2a19e 100644 --- a/orator/utils/qmarker.py +++ b/orator/utils/qmarker.py @@ -5,21 +5,22 @@ class Qmarker(object): - RE_QMARK = re.compile(r'\?\?|\?|%') + RE_QMARK = re.compile(r"\?\?|\?|%") @classmethod def qmark(cls, query): """ Convert a "qmark" query into "format" style. """ + def sub_sequence(m): s = m.group(0) - if s == '??': - return '?' - if s == '%': - return '%%' + if s == "??": + return "?" + if s == "%": + return "%%" else: - return '%s' + return "%s" return cls.RE_QMARK.sub(sub_sequence, query) diff --git a/orator/utils/url.py b/orator/utils/url.py index dc04fb10..19bd7f79 100644 --- a/orator/utils/url.py +++ b/orator/utils/url.py @@ -44,8 +44,16 @@ class URL(object): """ - def __init__(self, drivername, username=None, password=None, - host=None, port=None, database=None, query=None): + def __init__( + self, + drivername, + username=None, + password=None, + host=None, + port=None, + database=None, + query=None, + ): self.drivername = drivername self.username = username self.password = password @@ -62,22 +70,21 @@ def __to_string__(self, hide_password=True): if self.username is not None: s += _rfc_1738_quote(self.username) if self.password is not None: - s += ':' + ('***' if hide_password - else _rfc_1738_quote(self.password)) + s += ":" + ("***" if hide_password else _rfc_1738_quote(self.password)) s += "@" if self.host is not None: - if ':' in self.host: + if ":" in self.host: s += "[%s]" % self.host else: s += self.host if self.port is not None: - s += ':' + str(self.port) + s += ":" + str(self.port) if self.database is not None: - s += '/' + self.database + s += "/" + self.database if self.query: keys = list(self.query) keys.sort() - s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys) + s += "?" 
+ "&".join("%s=%s" % (k, self.query[k]) for k in keys) return s def __str__(self): @@ -90,43 +97,46 @@ def __hash__(self): return hash(str(self)) def __eq__(self, other): - return \ - isinstance(other, URL) and \ - self.drivername == other.drivername and \ - self.username == other.username and \ - self.password == other.password and \ - self.host == other.host and \ - self.database == other.database and \ - self.query == other.query + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + ) def get_backend_name(self): - if '+' not in self.drivername: + if "+" not in self.drivername: return self.drivername else: - return self.drivername.split('+')[0] + return self.drivername.split("+")[0] def get_driver_name(self): - if '+' not in self.drivername: + if "+" not in self.drivername: return self.get_dialect().driver else: - return self.drivername.split('+')[1] + return self.drivername.split("+")[1] def get_dialect(self): """Return the SQLAlchemy database dialect class corresponding to this URL's driver name. 
""" - if '+' not in self.drivername: + if "+" not in self.drivername: name = self.drivername else: - name = self.drivername.replace('+', '.') + name = self.drivername.replace("+", ".") cls = registry.load(name) # check for legacy dialects that # would return a module with 'dialect' as the # actual class - if hasattr(cls, 'dialect') and \ - isinstance(cls.dialect, type) and \ - issubclass(cls.dialect, Dialect): + if ( + hasattr(cls, "dialect") + and isinstance(cls.dialect, type) + and issubclass(cls.dialect, Dialect) + ): return cls.dialect else: return cls @@ -146,7 +156,7 @@ def translate_connect_args(self, names=[], **kw): """ translated = {} - attribute_names = ['host', 'database', 'username', 'password', 'port'] + attribute_names = ["host", "database", "username", "password", "port"] for sname in attribute_names: if names: name = names.pop(0) @@ -173,7 +183,8 @@ def make_url(name_or_url): def _parse_rfc1738_args(name): - pattern = re.compile(r''' + pattern = re.compile( + r""" (?P[\w\+]+):// (?: (?P[^:/]*) @@ -187,40 +198,40 @@ def _parse_rfc1738_args(name): (?::(?P[^/]*))? )? (?:/(?P.*))? 
- ''', re.X) + """, + re.X, + ) m = pattern.match(name) if m is not None: components = m.groupdict() - if components['database'] is not None: - tokens = components['database'].split('?', 2) - components['database'] = tokens[0] - query = ( - len(tokens) > 1 and dict(parse_qsl(tokens[1]))) or None + if components["database"] is not None: + tokens = components["database"].split("?", 2) + components["database"] = tokens[0] + query = (len(tokens) > 1 and dict(parse_qsl(tokens[1]))) or None if PY2 and query is not None: - query = dict((k.encode('ascii'), query[k]) for k in query) + query = dict((k.encode("ascii"), query[k]) for k in query) else: query = None - components['query'] = query + components["query"] = query - if components['username'] is not None: - components['username'] = _rfc_1738_unquote(components['username']) + if components["username"] is not None: + components["username"] = _rfc_1738_unquote(components["username"]) - if components['password'] is not None: - components['password'] = _rfc_1738_unquote(components['password']) + if components["password"] is not None: + components["password"] = _rfc_1738_unquote(components["password"]) - ipv4host = components.pop('ipv4host') - ipv6host = components.pop('ipv6host') - components['host'] = ipv4host or ipv6host - name = components.pop('name') + ipv4host = components.pop("ipv4host") + ipv6host = components.pop("ipv6host") + components["host"] = ipv4host or ipv6host + name = components.pop("name") return URL(name, **components) else: - raise ArgumentError( - "Could not parse rfc1738 URL from string '%s'" % name) + raise ArgumentError("Could not parse rfc1738 URL from string '%s'" % name) def _rfc_1738_quote(text): - return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text) + return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text) def _rfc_1738_unquote(text): @@ -228,12 +239,10 @@ def _rfc_1738_unquote(text): def _parse_keyvalue_args(name): - m = re.match(r'(\w+)://(.*)', name) + m = 
re.match(r"(\w+)://(.*)", name) if m is not None: (name, args) = m.group(1, 2) opts = dict(parse_qsl(args)) return URL(name, *opts) else: return None - - diff --git a/orator/version.py b/orator/version.py deleted file mode 100644 index 06e67340..00000000 --- a/orator/version.py +++ /dev/null @@ -1,3 +0,0 @@ -# -*- coding: utf-8 -*- - -VERSION = '0.9.2' diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..c067e646 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,62 @@ +[tool.poetry] +name = "orator" +version = "0.9.9" +description = "The Orator ORM provides a simple yet beautiful ActiveRecord implementation." + +license = "MIT" + +authors = [ + "Sébastien Eustace " +] + +readme = 'README.rst' + +repository = "https://github.com/sdispater/orator" +homepage = "https://orator-orm.com/" + +keywords = ['database', 'orm'] + + +[tool.poetry.dependencies] +python = "~2.7 || ^3.5" +backpack = "^0.1" +blinker = "^1.4" +cleo = "^0.6" +inflection = "^0.3" +Faker = "^0.8" +lazy-object-proxy = "^1.2" +pendulum = "^1.4" +pyaml = "^16.12" +pyyaml = "^5.1" +Pygments = "^2.2" +simplejson = "^3.10" +six = "^1.10" +wrapt = "^1.10" + +# Extras +psycopg2 = { version = "^2.7", optional = true } +PyMySQL = { version = "^0.7", optional = true } +mysqlclient = { version = "^1.3", optional = true } + + +[tool.poetry.dev-dependencies] +flexmock = "0.9.7" +pytest = "^3.5" +pytest-mock = "^1.6" +tox = "^3.5" +pre-commit = "^1.11" + + +[tool.poetry.extras] +mysql = ["mysqlclient"] +mysql-python = ["PyMySQL"] +pgsql = ["psycopg2"] + + +[tool.poetry.scripts] +orator = 'orator.commands.application:application.run' + + +[build-system] +requires = ["poetry>=0.12a3"] +build-backend = "poetry.masonry.api" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index dee82e7b..00000000 --- a/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -simplejson -pendulum -backpack -inflection -six -cleo>=0.4.1 -blinker -lazy-object-proxy -fake-factory -wrapt -pyaml 
-pygments diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 2a9acf13..00000000 --- a/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[bdist_wheel] -universal = 1 diff --git a/setup.py b/setup.py deleted file mode 100644 index 3f67d89c..00000000 --- a/setup.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- - -import os -from setuptools import find_packages -from distutils.core import setup - - -here = os.path.abspath(os.path.dirname(__file__)) - -def get_version(): - with open(os.path.join(here, 'orator/version.py')) as f: - variables = {} - exec(f.read(), variables) - - version = variables.get('VERSION') - if version: - return version - - raise RuntimeError('No version info found.') - -__version__ = get_version() - -with open(os.path.join(here, 'requirements.txt')) as f: - requirements = f.readlines() - -setup_kwargs = dict( - name='orator', - license='MIT', - version=__version__, - description='The Orator ORM provides a simple yet beautiful ActiveRecord implementation.', - long_description=open('README.rst').read(), - entry_points={ - 'console_scripts': ['orator=orator.commands.application:application.run'], - }, - author='Sébastien Eustace', - author_email='sebastien.eustace@gmail.com', - url='https://github.com/sdispater/orator', - download_url='https://github.com/sdispater/orator/archive/%s.tar.gz' % __version__, - packages=find_packages(exclude=['tests']), - install_requires=requirements, - tests_require=['pytest', 'mock', 'flexmock==0.9.7', 'mysqlclient', 'psycopg2'], - test_suite='nose.collector', - classifiers=[ - 'Intended Audience :: Developers', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Topic :: Software Development :: Libraries :: Python Modules', - ], -) - -setup(**setup_kwargs) diff --git a/tests-requirements.txt b/tests-requirements.txt deleted file mode 100644 index 843db42e..00000000 --- a/tests-requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ --r requirements.txt -pytest -pytest-mock 
-flexmock==0.9.7 -psycopg2 diff --git a/tests/__init__.py b/tests/__init__.py index f59d6aae..fbdad4b4 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -15,34 +15,32 @@ class OratorTestCase(TestCase): - def tearDown(self): - if hasattr(self, 'local_database'): + if hasattr(self, "local_database"): os.remove(self.local_database) def init_database(self): - self.local_database = '/tmp/orator_test_database.db' + self.local_database = "/tmp/orator_test_database.db" if os.path.exists(self.local_database): os.remove(self.local_database) - self.manager = DatabaseManager({ - 'default': 'sqlite', - 'sqlite': { - 'driver': 'sqlite', - 'database': self.local_database + self.manager = DatabaseManager( + { + "default": "sqlite", + "sqlite": {"driver": "sqlite", "database": self.local_database}, } - }) + ) with self.manager.transaction(): try: self.manager.statement( - 'CREATE TABLE `users` (' - 'id INTEGER PRIMARY KEY NOT NULL, ' - 'name CHAR(50) NOT NULL, ' - 'created_at DATETIME DEFAULT CURRENT_TIMESTAMP, ' - 'updated_at DATETIME DEFAULT CURRENT_TIMESTAMP' - ')' + "CREATE TABLE `users` (" + "id INTEGER PRIMARY KEY NOT NULL, " + "name CHAR(50) NOT NULL, " + "created_at DATETIME DEFAULT CURRENT_TIMESTAMP, " + "updated_at DATETIME DEFAULT CURRENT_TIMESTAMP" + ")" ) except Exception: pass diff --git a/tests/commands/__init__.py b/tests/commands/__init__.py index d6cd4d88..18658674 100644 --- a/tests/commands/__init__.py +++ b/tests/commands/__init__.py @@ -6,11 +6,10 @@ class OratorCommandTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() - def run_command(self, command, options=None, input_stream=None): + def run_command(self, command, options=None): """ Run the command. 
@@ -20,18 +19,12 @@ def run_command(self, command, options=None, input_stream=None): if options is None: options = [] - options = [('command', command.get_name())] + options + options = [("command", command.get_name())] + options application = Application() application.add(command) - if input_stream: - dialog = command.get_helper('question') - dialog.__class__.input_stream = input_stream - command_tester = CommandTester(command) command_tester.execute(options) return command_tester - - diff --git a/tests/commands/migrations/__init__.py b/tests/commands/migrations/__init__.py index 633f8661..40a96afc 100644 --- a/tests/commands/migrations/__init__.py +++ b/tests/commands/migrations/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/commands/migrations/test_install_command.py b/tests/commands/migrations/test_install_command.py index fc3f4fc1..4a8a676f 100644 --- a/tests/commands/migrations/test_install_command.py +++ b/tests/commands/migrations/test_install_command.py @@ -7,13 +7,12 @@ class MigrateInstallCommandTestCase(OratorCommandTestCase): - def test_execute_calls_repository_to_install(self): repo_mock = flexmock(DatabaseMigrationRepository) - repo_mock.should_receive('set_source').once().with_args('foo') - repo_mock.should_receive('create_repository').once() + repo_mock.should_receive("set_source").once().with_args("foo") + repo_mock.should_receive("create_repository").once() command = flexmock(InstallCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) - self.run_command(command, [('--database', 'foo')]) + self.run_command(command, [("--database", "foo")]) diff --git a/tests/commands/migrations/test_make_command.py b/tests/commands/migrations/test_make_command.py index 4b3181bd..b72269eb 100644 --- a/tests/commands/migrations/test_make_command.py +++ b/tests/commands/migrations/test_make_command.py @@ -8,33 +8,39 @@ class MigrateMakeCommandTestCase(OratorCommandTestCase): - 
def test_basic_create_gives_creator_proper_arguments(self): creator_mock = flexmock(MigrationCreator) - creator_mock.should_receive('create').once()\ - .with_args('create_foo', os.path.join(os.getcwd(), 'migrations'), None, False).and_return('foo') + creator_mock.should_receive("create").once().with_args( + "create_foo", os.path.join(os.getcwd(), "migrations"), None, False + ).and_return("foo") command = flexmock(MigrateMakeCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) - self.run_command(command, [('name', 'create_foo')]) + self.run_command(command, [("name", "create_foo")]) def test_basic_create_gives_creator_proper_arguments_when_table_is_set(self): creator_mock = flexmock(MigrationCreator) - creator_mock.should_receive('create').once()\ - .with_args('create_foo', os.path.join(os.getcwd(), 'migrations'), 'users', False).and_return('foo') + creator_mock.should_receive("create").once().with_args( + "create_foo", os.path.join(os.getcwd(), "migrations"), "users", False + ).and_return("foo") command = flexmock(MigrateMakeCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) - self.run_command(command, [('name', 'create_foo'), ('--table', 'users')]) + self.run_command(command, [("name", "create_foo"), ("--table", "users")]) - def test_basic_create_gives_creator_proper_arguments_when_table_is_set_with_create(self): + def test_basic_create_gives_creator_proper_arguments_when_table_is_set_with_create( + self + ): creator_mock = flexmock(MigrationCreator) - creator_mock.should_receive('create').once()\ - .with_args('create_foo', os.path.join(os.getcwd(), 'migrations'), 'users', True).and_return('foo') + creator_mock.should_receive("create").once().with_args( + "create_foo", os.path.join(os.getcwd(), "migrations"), "users", True + ).and_return("foo") command = flexmock(MigrateMakeCommand()) - 
command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) - self.run_command(command, [('name', 'create_foo'), ('--table', 'users'), '--create']) + self.run_command( + command, [("name", "create_foo"), ("--table", "users"), "--create"] + ) diff --git a/tests/commands/migrations/test_migrate_command.py b/tests/commands/migrations/test_migrate_command.py index a88cc1a3..4a69058a 100644 --- a/tests/commands/migrations/test_migrate_command.py +++ b/tests/commands/migrations/test_migrate_command.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- import os -from io import BytesIO from flexmock import flexmock -from cleo import Output from orator.migrations import Migrator from orator.commands.migrations import MigrateCommand from orator import DatabaseManager @@ -11,74 +9,97 @@ class MigrateCommandTestCase(OratorCommandTestCase): - def test_basic_migrations_call_migrator_with_proper_arguments(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args(None) - migrator_mock.should_receive('run').once().with_args(os.path.join(os.getcwd(), 'migrations'), False) - migrator_mock.should_receive('get_notes').and_return([]) - migrator_mock.should_receive('repository_exists').once().and_return(True) + migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("run").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ) + migrator_mock.should_receive("get_notes").and_return([]) + migrator_mock.should_receive("repository_exists").once().and_return(True) command = flexmock(MigrateCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) - self.run_command(command, 
input_stream=self.get_input_stream('y\n')) + self.run_command(command) def test_migration_repository_create_when_necessary(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args(None) - migrator_mock.should_receive('run').once().with_args(os.path.join(os.getcwd(), 'migrations'), False) - migrator_mock.should_receive('get_notes').and_return([]) - migrator_mock.should_receive('repository_exists').once().and_return(False) + migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("run").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ) + migrator_mock.should_receive("get_notes").and_return([]) + migrator_mock.should_receive("repository_exists").once().and_return(False) command = flexmock(MigrateCommand()) - command.should_receive('_get_config').and_return({}) - command.should_receive('call').once()\ - .with_args('migrate:install', [('--config', None)]) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) + command.should_receive("call").once().with_args( + "migrate:install", [("--config", None)] + ) - self.run_command(command, input_stream=self.get_input_stream('y\n')) + self.run_command(command) def test_migration_can_be_pretended(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args(None) - migrator_mock.should_receive('run').once().with_args(os.path.join(os.getcwd(), 'migrations'), True) - migrator_mock.should_receive('get_notes').and_return([]) - migrator_mock.should_receive('repository_exists').once().and_return(True) + 
migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("run").once().with_args( + os.path.join(os.getcwd(), "migrations"), True + ) + migrator_mock.should_receive("get_notes").and_return([]) + migrator_mock.should_receive("repository_exists").once().and_return(True) command = flexmock(MigrateCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) - self.run_command(command, [('--pretend', True)], input_stream=self.get_input_stream('y\n')) + self.run_command(command, [("--pretend", True)]) def test_migration_database_can_be_set(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args('foo') - migrator_mock.should_receive('run').once().with_args(os.path.join(os.getcwd(), 'migrations'), False) - migrator_mock.should_receive('get_notes').and_return([]) - migrator_mock.should_receive('repository_exists').once().and_return(False) + migrator_mock.should_receive("set_connection").once().with_args("foo") + migrator_mock.should_receive("run").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ) + migrator_mock.should_receive("get_notes").and_return([]) + migrator_mock.should_receive("repository_exists").once().and_return(False) command = flexmock(MigrateCommand()) - command.should_receive('_get_config').and_return({}) - command.should_receive('call').once()\ - .with_args('migrate:install', [('--database', 'foo'), ('--config', None)]) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) + command.should_receive("call").once().with_args( + "migrate:install", [("--database", "foo"), ("--config", None)] + ) + + self.run_command(command, 
[("--database", "foo")]) + + def test_migration_can_be_forced(self): + resolver = flexmock(DatabaseManager) + resolver.should_receive("connection").and_return(None) - self.run_command(command, [('--database', 'foo')], input_stream=self.get_input_stream('y\n')) + migrator_mock = flexmock(Migrator) + migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("run").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ) + migrator_mock.should_receive("get_notes").and_return([]) + migrator_mock.should_receive("repository_exists").once().and_return(True) - def get_input_stream(self, input_): - stream = BytesIO() - stream.write(input_.encode()) - stream.seek(0) + command = flexmock(MigrateCommand()) + command.should_receive("_get_config").and_return({}) - return stream + self.run_command(command, [("--force", True)]) diff --git a/tests/commands/migrations/test_refresh_command.py b/tests/commands/migrations/test_refresh_command.py new file mode 100644 index 00000000..657141e8 --- /dev/null +++ b/tests/commands/migrations/test_refresh_command.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +import os +from flexmock import flexmock +from orator.migrations import Migrator +from orator.commands.migrations import RefreshCommand +from orator import DatabaseManager +from .. 
import OratorCommandTestCase + + +class RefreshCommandTestCase(OratorCommandTestCase): + def test_refresh_runs_the_seeder_when_seed_option_set(self): + resolver = flexmock(DatabaseManager) + resolver.should_receive("connection").and_return(None) + + command = flexmock(RefreshCommand()) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) + command.should_receive("call").with_args("migrate:reset", object).and_return( + True + ) + command.should_receive("call").with_args("migrate", object).and_return(True) + command.should_receive("_run_seeder") + + self.run_command(command, [("--seed")]) + + def test_refresh_does_not_run_the_seeder_when_seed_option_absent(self): + resolver = flexmock(DatabaseManager) + resolver.should_receive("connection").and_return(None) + + command = flexmock(RefreshCommand()) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) + command.should_receive("call").with_args("migrate:reset", object).and_return( + True + ) + command.should_receive("call").with_args("migrate", object).and_return(True) + + self.run_command(command, []) diff --git a/tests/commands/migrations/test_reset_command.py b/tests/commands/migrations/test_reset_command.py index 67cc5216..32581772 100644 --- a/tests/commands/migrations/test_reset_command.py +++ b/tests/commands/migrations/test_reset_command.py @@ -1,68 +1,77 @@ # -*- coding: utf-8 -*- import os -from io import BytesIO from flexmock import flexmock from orator.migrations import Migrator from orator.commands.migrations import ResetCommand from orator import DatabaseManager -from orator.connections import Connection from .. 
import OratorCommandTestCase class ResetCommandTestCase(OratorCommandTestCase): - def test_basic_migrations_call_migrator_with_proper_arguments(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args(None) - migrator_mock.should_receive('reset').once()\ - .with_args(os.path.join(os.getcwd(), 'migrations'), False)\ - .and_return(2) - migrator_mock.should_receive('get_notes').and_return([]) + migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("reset").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ).and_return(2) + migrator_mock.should_receive("get_notes").and_return([]) command = flexmock(ResetCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) - self.run_command(command, input_stream=self.get_input_stream('y\n')) + self.run_command(command) def test_migration_can_be_pretended(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args(None) - migrator_mock.should_receive('reset').once()\ - .with_args(os.path.join(os.getcwd(), 'migrations'), True)\ - .and_return(2) - migrator_mock.should_receive('get_notes').and_return([]) + migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("reset").once().with_args( + os.path.join(os.getcwd(), "migrations"), True + ).and_return(2) + migrator_mock.should_receive("get_notes").and_return([]) command = flexmock(ResetCommand()) - command.should_receive('_get_config').and_return({}) + 
command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) - self.run_command(command, [('--pretend', True)], input_stream=self.get_input_stream('y\n')) + self.run_command(command, [("--pretend", True)]) def test_migration_database_can_be_set(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args('foo') - migrator_mock.should_receive('reset').once()\ - .with_args(os.path.join(os.getcwd(), 'migrations'), False)\ - .and_return(2) - migrator_mock.should_receive('get_notes').and_return([]) + migrator_mock.should_receive("set_connection").once().with_args("foo") + migrator_mock.should_receive("reset").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ).and_return(2) + migrator_mock.should_receive("get_notes").and_return([]) command = flexmock(ResetCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) + + self.run_command(command, [("--database", "foo")]) + + def test_migration_can_be_forced(self): + resolver = flexmock(DatabaseManager) + resolver.should_receive("connection").and_return(None) - self.run_command(command, [('--database', 'foo')], input_stream=self.get_input_stream('y\n')) + migrator_mock = flexmock(Migrator) + migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("reset").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ).and_return(2) + migrator_mock.should_receive("get_notes").and_return([]) - def get_input_stream(self, input_): - stream = BytesIO() - stream.write(input_.encode()) - stream.seek(0) + command = flexmock(ResetCommand()) + command.should_receive("_get_config").and_return({}) - return stream + 
self.run_command(command, [("--force", True)]) diff --git a/tests/commands/migrations/test_rollback_command.py b/tests/commands/migrations/test_rollback_command.py index 39973f26..9de77bba 100644 --- a/tests/commands/migrations/test_rollback_command.py +++ b/tests/commands/migrations/test_rollback_command.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- import os -from io import BytesIO from flexmock import flexmock -from cleo import Output from orator.migrations import Migrator from orator.commands.migrations import RollbackCommand from orator import DatabaseManager @@ -11,52 +9,69 @@ class RollbackCommandTestCase(OratorCommandTestCase): - def test_basic_migrations_call_migrator_with_proper_arguments(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args(None) - migrator_mock.should_receive('rollback').once().with_args(os.path.join(os.getcwd(), 'migrations'), False) - migrator_mock.should_receive('get_notes').and_return([]) + migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("rollback").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ) + migrator_mock.should_receive("get_notes").and_return([]) command = flexmock(RollbackCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) - self.run_command(command, input_stream=self.get_input_stream('y\n')) + self.run_command(command) def test_migration_can_be_pretended(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args(None) - 
migrator_mock.should_receive('rollback').once().with_args(os.path.join(os.getcwd(), 'migrations'), True) - migrator_mock.should_receive('get_notes').and_return([]) + migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("rollback").once().with_args( + os.path.join(os.getcwd(), "migrations"), True + ) + migrator_mock.should_receive("get_notes").and_return([]) command = flexmock(RollbackCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) - self.run_command(command, [('--pretend', True)], input_stream=self.get_input_stream('y\n')) + self.run_command(command, [("--pretend", True)]) def test_migration_database_can_be_set(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator_mock = flexmock(Migrator) - migrator_mock.should_receive('set_connection').once().with_args('foo') - migrator_mock.should_receive('rollback').once().with_args(os.path.join(os.getcwd(), 'migrations'), False) - migrator_mock.should_receive('get_notes').and_return([]) + migrator_mock.should_receive("set_connection").once().with_args("foo") + migrator_mock.should_receive("rollback").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ) + migrator_mock.should_receive("get_notes").and_return([]) command = flexmock(RollbackCommand()) - command.should_receive('_get_config').and_return({}) + command.should_receive("_get_config").and_return({}) + command.should_receive("confirm").and_return(True) + + self.run_command(command, [("--database", "foo")]) + + def test_migration_can_be_forced(self): + resolver = flexmock(DatabaseManager) + resolver.should_receive("connection").and_return(None) - self.run_command(command, [('--database', 'foo')], input_stream=self.get_input_stream('y\n')) + migrator_mock = 
flexmock(Migrator) + migrator_mock.should_receive("set_connection").once().with_args(None) + migrator_mock.should_receive("rollback").once().with_args( + os.path.join(os.getcwd(), "migrations"), False + ) + migrator_mock.should_receive("get_notes").and_return([]) - def get_input_stream(self, input_): - stream = BytesIO() - stream.write(input_.encode()) - stream.seek(0) + command = flexmock(RollbackCommand()) + command.should_receive("_get_config").and_return({}) - return stream + self.run_command(command, [("--force", True)]) diff --git a/tests/connections/__init__.py b/tests/connections/__init__.py index 633f8661..40a96afc 100644 --- a/tests/connections/__init__.py +++ b/tests/connections/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/connections/test_connection.py b/tests/connections/test_connection.py index 557a74fd..870ef8fe 100644 --- a/tests/connections/test_connection.py +++ b/tests/connections/test_connection.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import threading + try: from Queue import Queue except ImportError: @@ -17,24 +18,23 @@ class ConnectionTestCase(OratorTestCase): - def test_table_returns_query_builder(self): - connection = Connection(None, 'database') - builder = connection.table('users') + connection = Connection(None, "database") + builder = connection.table("users") self.assertIsInstance(builder, QueryBuilder) - self.assertEqual('users', builder.from__) + self.assertEqual("users", builder.from__) self.assertEqual(connection.get_query_grammar(), builder.get_grammar()) def test_transaction(self): - connection = Connection(None, 'database') + connection = Connection(None, "database") connection.begin_transaction = mock.MagicMock(unsafe=True) connection.commit = mock.MagicMock(unsafe=True) connection.rollback = mock.MagicMock(unsafe=True) connection.insert = mock.MagicMock(return_value=1) with connection.transaction(): - connection.table('users').insert({'name': 'foo'}) + connection.table("users").insert({"name": 
"foo"}) connection.begin_transaction.assert_called_once() connection.commit.assert_called_once() @@ -46,29 +46,33 @@ def test_transaction(self): try: with connection.transaction(): - connection.table('users').insert({'name': 'foo'}) - raise Exception('foo') + connection.table("users").insert({"name": "foo"}) + raise Exception("foo") except Exception as e: - self.assertEqual('foo', str(e)) + self.assertEqual("foo", str(e)) connection.begin_transaction.assert_called_once() connection.rollback.assert_called_once() self.assertFalse(connection.commit.called) def test_try_again_if_caused_by_lost_connection_is_called(self): - connection = flexmock(Connection(None, 'database')) + connection = flexmock(Connection(None, "database")) cursor = flexmock() - connection.should_receive('_try_again_if_caused_by_lost_connection').once() - connection.should_receive('_get_cursor_for_select').and_return(cursor) - connection.should_receive('reconnect') - cursor.should_receive('execute').and_raise(Exception('error')) + connection.should_receive("_try_again_if_caused_by_lost_connection").once() + connection.should_receive("_get_cursor_for_select").and_return(cursor) + connection.should_receive("reconnect") + cursor.should_receive("execute").and_raise(Exception("error")) connection.select('SELECT * FROM "users"') + def test_lost_connection_returns_true_with_capitalized_error(self): + connection = Connection(None, "database") + self.assertTrue(connection._caused_by_lost_connection("Lost Connection")) + def test_prefix_set_to_none(self): - connection = Connection(None, 'database', None) + connection = Connection(None, "database", None) self.assertIsNotNone(connection.get_table_prefix()) - self.assertEqual('', connection.get_table_prefix()) + self.assertEqual("", connection.get_table_prefix()) class ConnectionThreadLocalTest(OratorTestCase): @@ -80,13 +84,15 @@ def test_create_thread_local(self): def create_user_thread(low, hi): for _ in range(low, hi): - User.create(name='u%d' % i) + 
User.create(name="u%d" % i) User.get_connection_resolver().disconnect() threads = [] for i in range(self.threads): - threads.append(threading.Thread(target=create_user_thread, args=(i*10, i * 10 + 10))) + threads.append( + threading.Thread(target=create_user_thread, args=(i * 10, i * 10 + 10)) + ) [t.start() for t in threads] [t.join() for t in threads] @@ -105,7 +111,9 @@ def reader_thread(q, num): threads = [] for i in range(self.threads): - threads.append(threading.Thread(target=reader_thread, args=(data_queue, 20))) + threads.append( + threading.Thread(target=reader_thread, args=(data_queue, 20)) + ) [t.start() for t in threads] [t.join() for t in threads] diff --git a/tests/connections/test_mysql_connection.py b/tests/connections/test_mysql_connection.py index c17566e2..2c9a2f9f 100644 --- a/tests/connections/test_mysql_connection.py +++ b/tests/connections/test_mysql_connection.py @@ -6,18 +6,17 @@ class MySQLConnectionTestCase(OratorTestCase): - def test_marker_is_properly_set(self): - connection = MySQLConnection(None, 'database', '', {'use_qmark': True}) + connection = MySQLConnection(None, "database", "", {"use_qmark": True}) - self.assertEqual('?', connection.get_marker()) + self.assertEqual("?", connection.get_marker()) def test_marker_default(self): - connection = MySQLConnection(None, 'database', '', {}) + connection = MySQLConnection(None, "database", "", {}) self.assertIsNone(connection.get_marker()) def test_marker_use_qmark_false(self): - connection = MySQLConnection(None, 'database', '', {'use_qmark': False}) + connection = MySQLConnection(None, "database", "", {"use_qmark": False}) self.assertIsNone(connection.get_marker()) diff --git a/tests/connections/test_postgres_connection.py b/tests/connections/test_postgres_connection.py index 5225223a..48d62239 100644 --- a/tests/connections/test_postgres_connection.py +++ b/tests/connections/test_postgres_connection.py @@ -6,18 +6,17 @@ class PostgresConnectionTestCase(OratorTestCase): - def 
test_marker_is_properly_set(self): - connection = PostgresConnection(None, 'database', '', {'use_qmark': True}) + connection = PostgresConnection(None, "database", "", {"use_qmark": True}) - self.assertEqual('?', connection.get_marker()) + self.assertEqual("?", connection.get_marker()) def test_marker_default(self): - connection = PostgresConnection(None, 'database', '', {}) + connection = PostgresConnection(None, "database", "", {}) self.assertIsNone(connection.get_marker()) def test_marker_use_qmark_false(self): - connection = PostgresConnection(None, 'database', '', {'use_qmark': False}) + connection = PostgresConnection(None, "database", "", {"use_qmark": False}) self.assertIsNone(connection.get_marker()) diff --git a/tests/integrations/__init__.py b/tests/integrations/__init__.py index 8a09c847..a59a7f33 100644 --- a/tests/integrations/__init__.py +++ b/tests/integrations/__init__.py @@ -1,18 +1,55 @@ # -*- coding: utf-8 -*- +import os +import json +import logging import pendulum import simplejson as json -from datetime import datetime, timedelta + +from datetime import datetime, timedelta, date from backpack import collect from orator import Model, Collection, DatabaseManager -from orator.orm import morph_to, has_one, has_many, belongs_to_many, morph_many, belongs_to, scope, accessor +from orator.orm import ( + morph_to, + has_one, + has_many, + belongs_to_many, + morph_many, + belongs_to, + scope, + accessor, +) from orator.orm.relations import BelongsToMany from orator.exceptions.orm import ModelNotFound -from orator.exceptions.query import QueryException -class IntegrationTestCase(object): +logger = logging.getLogger("orator.connection.queries") +logger.setLevel(logging.DEBUG) + + +class LoggedQueriesFormatter(logging.Formatter): + def __init__(self, fmt=None, datefmt=None, style="%"): + super(LoggedQueriesFormatter, self).__init__() + + self.logged_queries = [] + + def format(self, record): + self.logged_queries.append(record.query) + + return 
super(LoggedQueriesFormatter, self).format(record) + + def reset(self): + self.logged_queries = [] + + +formatter = LoggedQueriesFormatter() +handler = logging.StreamHandler(open(os.devnull, "w")) + +handler.setFormatter(formatter) +logger.addHandler(handler) + +class IntegrationTestCase(object): @classmethod def setUpClass(cls): Model.set_connection_resolver(cls.get_connection_resolver()) @@ -26,12 +63,12 @@ def get_connection_resolver(cls): # Adding another connection to test connection switching config = cls.get_manager_config() - config['test'] = { - 'driver': 'sqlite', - 'database': ':memory:' - } + config["test"] = {"driver": "sqlite", "database": ":memory:"} + + db = DatabaseManager(config) + db.connection().enable_query_log() - return DatabaseManager(config) + return db @classmethod def tearDownClass(cls): @@ -43,134 +80,174 @@ def marker(self): def setUp(self): self.migrate() - self.migrate('test') + self.migrate("test") + + formatter.reset() def tearDown(self): self.revert() - self.revert('test') + self.revert("test") def test_basic_model_retrieval(self): - OratorTestUser.create(email='john@doe.com') - model = OratorTestUser.where('email', 'john@doe.com').first() - self.assertEqual('john@doe.com', model.email) + OratorTestUser.create(email="john@doe.com") + model = OratorTestUser.where("email", "john@doe.com").first() + self.assertEqual("john@doe.com", model.email) def test_basic_model_collection_retrieval(self): - OratorTestUser.create(id=1, email='john@doe.com') - OratorTestUser.create(id=2, email='jane@doe.com') + OratorTestUser.create(id=1, email="john@doe.com") + OratorTestUser.create(id=2, email="jane@doe.com") - models = OratorTestUser.oldest('id').get() + models = OratorTestUser.oldest("id").get() self.assertEqual(2, len(models)) self.assertIsInstance(models, Collection) self.assertIsInstance(models[0], OratorTestUser) self.assertIsInstance(models[1], OratorTestUser) - self.assertEqual('john@doe.com', models[0].email) - 
self.assertEqual('jane@doe.com', models[1].email) + self.assertEqual("john@doe.com", models[0].email) + self.assertEqual("jane@doe.com", models[1].email) def test_lists_retrieval(self): - OratorTestUser.create(id=1, email='john@doe.com') - OratorTestUser.create(id=2, email='jane@doe.com') + OratorTestUser.create(id=1, email="john@doe.com") + OratorTestUser.create(id=2, email="jane@doe.com") - simple = OratorTestUser.oldest('id').lists('email') - keyed = OratorTestUser.oldest('id').lists('email', 'id') + simple = OratorTestUser.oldest("id").lists("email") + keyed = OratorTestUser.oldest("id").lists("email", "id") - self.assertEqual(['john@doe.com', 'jane@doe.com'], simple) - self.assertEqual({1: 'john@doe.com', 2: 'jane@doe.com'}, keyed) + self.assertEqual(["john@doe.com", "jane@doe.com"], simple) + self.assertEqual({1: "john@doe.com", 2: "jane@doe.com"}, keyed) def test_find_or_fail(self): - OratorTestUser.create(id=1, email='john@doe.com') - OratorTestUser.create(id=2, email='jane@doe.com') + OratorTestUser.create(id=1, email="john@doe.com") + OratorTestUser.create(id=2, email="jane@doe.com") single = OratorTestUser.find_or_fail(1) multiple = OratorTestUser.find_or_fail([1, 2]) self.assertIsInstance(single, OratorTestUser) - self.assertEqual('john@doe.com', single.email) + self.assertEqual("john@doe.com", single.email) self.assertIsInstance(multiple, Collection) self.assertIsInstance(multiple[0], OratorTestUser) self.assertIsInstance(multiple[1], OratorTestUser) def test_find_or_fail_with_single_id_raises_model_not_found_exception(self): - self.assertRaises( - ModelNotFound, - OratorTestUser.find_or_fail, - 1 - ) + self.assertRaises(ModelNotFound, OratorTestUser.find_or_fail, 1) def test_find_or_fail_with_multiple_ids_raises_model_not_found_exception(self): - self.assertRaises( - ModelNotFound, - OratorTestUser.find_or_fail, - [1, 2] - ) + self.assertRaises(ModelNotFound, OratorTestUser.find_or_fail, [1, 2]) def test_one_to_one_relationship(self): - user = 
OratorTestUser.create(email='john@doe.com') - user.post().create(name='First Post') + user = OratorTestUser.create(email="john@doe.com") + user.post().create(name="First Post") post = user.post user = post.user - self.assertEqual('john@doe.com', user.email) - self.assertEqual('First Post', post.name) + self.assertEqual("john@doe.com", user.email) + self.assertEqual("First Post", post.name) def test_one_to_many_relationship(self): - user = OratorTestUser.create(email='john@doe.com') - user.posts().create(name='First Post') - user.posts().create(name='Second Post') + user = OratorTestUser.create(email="john@doe.com") + user.posts().create(name="First Post") + user.posts().create(name="Second Post") posts = user.posts - post2 = user.posts().where('name', 'Second Post').first() + post2 = user.posts().where("name", "Second Post").first() self.assertEqual(2, len(posts)) self.assertIsInstance(posts[0], OratorTestPost) self.assertIsInstance(posts[1], OratorTestPost) self.assertIsInstance(post2, OratorTestPost) - self.assertEqual('Second Post', post2.name) + self.assertEqual("Second Post", post2.name) self.assertIsInstance(post2.user, OratorTestUser) - self.assertEqual('john@doe.com', post2.user.email) + self.assertEqual("john@doe.com", post2.user.email) def test_basic_model_hydrate(self): - OratorTestUser.create(id=1, email='john@doe.com') - OratorTestUser.create(id=2, email='jane@doe.com') + OratorTestUser.create(id=1, email="john@doe.com") + OratorTestUser.create(id=2, email="jane@doe.com") models = OratorTestUser.hydrate_raw( - 'SELECT * FROM test_users WHERE email = %s' % self.marker, - ['jane@doe.com'], - self.connection().get_name() + "SELECT * FROM test_users WHERE email = %s" % self.marker, + ["jane@doe.com"], + self.connection().get_name(), ) self.assertIsInstance(models, Collection) self.assertIsInstance(models[0], OratorTestUser) - self.assertEqual('jane@doe.com', models[0].email) + self.assertEqual("jane@doe.com", models[0].email) 
self.assertEqual(self.connection().get_name(), models[0].get_connection_name()) self.assertEqual(1, len(models)) def test_has_on_self_referencing_belongs_to_many_relationship(self): - user = OratorTestUser.create(email='john@doe.com') - friend = user.friends().create(email='jane@doe.com') + user = OratorTestUser.create(email="john@doe.com") + friend = user.friends().create(email="jane@doe.com") - results = OratorTestUser.has('friends').get() + results = OratorTestUser.has("friends").get() self.assertEqual(1, len(results)) - self.assertEqual('john@doe.com', results.first().email) + self.assertEqual("john@doe.com", results.first().email) def test_basic_has_many_eager_loading(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - user.posts().create(name='First Post') - user = OratorTestUser.with_('posts').where('email', 'john@doe.com').first() + user = OratorTestUser.create(id=1, email="john@doe.com") + post = user.posts().create(name="First Post") + comment = post.comments().create(body="Text") + comment2 = post.comments().create(body="Text 2") + comment.children().save(comment2) + user = ( + OratorTestUser.with_("posts.comments.children.parent") + .where("email", "john@doe.com") + .first() + ) + + self.assertEqual("First Post", user.posts.first().name) + self.assertEqual("Text", user.posts.first().comments.first().body) + self.assertEqual( + "Text 2", user.posts.first().comments.first().children.first().body + ) + self.assertEqual( + "Text", user.posts.first().comments.first().children.first().parent.body + ) - self.assertEqual('First Post', user.posts.first().name) + queries = formatter.logged_queries + self.assertEqual(10, len(queries)) - post = OratorTestPost.with_('user').where('name', 'First Post').get() - self.assertEqual('john@doe.com', post.first().user.email) + formatter.reset() + + post = OratorTestPost.with_("user").where("name", "First Post").get() + self.assertEqual("john@doe.com", post.first().user.email) + + comment = ( + 
OratorTestComment.with_("parent.post.user").where("body", "Text 2").first() + ) + self.assertEqual("Text", comment.parent.body) + self.assertEqual("First Post", comment.parent.post.name) + self.assertEqual("john@doe.com", comment.parent.post.user.email) + + queries = formatter.logged_queries + self.assertEqual(6, len(queries)) + + def test_all_eager_loaded_transitive_relations_must_be_present(self): + user = OratorTestUser.create(id=1, email="john@doe.com") + post = user.posts().create(name="First Post") + comment = post.comments().create(body="Text") + comment2 = post.comments().create(body="Text 2") + comment.children().save(comment2) + post = ( + OratorTestPost.with_("user", "user.posts", "user.post") + .where("id", post.id) + .first() + ) + + data = post.serialize() + assert "user" in data + assert "posts" in data["user"] + assert "post" in data["user"] def test_basic_morph_many_relationship(self): - user = OratorTestUser.create(email='john@doe.com') - user.photos().create(name='Avatar 1') - user.photos().create(name='Avatar 2') - post = user.posts().create(name='First Post') - post.photos().create(name='Hero 1') - post.photos().create(name='Hero 2') + user = OratorTestUser.create(email="john@doe.com") + user.photos().create(name="Avatar 1") + user.photos().create(name="Avatar 2") + post = user.posts().create(name="First Post") + post.photos().create(name="Hero 1") + post.photos().create(name="Hero 2") self.assertIsInstance(user.photos, Collection) self.assertIsInstance(user.photos[0], OratorTestPhoto) @@ -179,157 +256,217 @@ def test_basic_morph_many_relationship(self): self.assertIsInstance(post.photos[0], OratorTestPhoto) self.assertEqual(2, len(user.photos)) self.assertEqual(2, len(post.photos)) - self.assertEqual('Avatar 1', user.photos[0].name) - self.assertEqual('Avatar 2', user.photos[1].name) - self.assertEqual('Hero 1', post.photos[0].name) - self.assertEqual('Hero 2', post.photos[1].name) + self.assertEqual("Avatar 1", user.photos[0].name) + 
self.assertEqual("Avatar 2", user.photos[1].name) + self.assertEqual("Hero 1", post.photos[0].name) + self.assertEqual("Hero 2", post.photos[1].name) - photos = OratorTestPhoto.order_by('name').get() + photos = OratorTestPhoto.order_by("name").get() self.assertIsInstance(photos, Collection) self.assertEqual(4, len(photos)) self.assertIsInstance(photos[0].imageable, OratorTestUser) self.assertIsInstance(photos[2].imageable, OratorTestPost) - self.assertEqual('john@doe.com', photos[1].imageable.email) - self.assertEqual('First Post', photos[3].imageable.name) + self.assertEqual("john@doe.com", photos[1].imageable.email) + self.assertEqual("First Post", photos[3].imageable.name) def test_multi_insert_with_different_values(self): date = pendulum.utcnow()._datetime - result = OratorTestPost.insert([ - { - 'user_id': 1, 'name': 'Post', 'created_at': date, 'updated_at': date - }, { - 'user_id': 2, 'name': 'Post', 'created_at': date, 'updated_at': date - } - ]) + user1 = OratorTestUser.create(email="john@doe.com") + user2 = OratorTestUser.create(email="jane@doe.com") + result = OratorTestPost.insert( + [ + { + "user_id": user1.id, + "name": "Post", + "created_at": date, + "updated_at": date, + }, + { + "user_id": user2.id, + "name": "Post", + "created_at": date, + "updated_at": date, + }, + ] + ) self.assertTrue(result) self.assertEqual(2, OratorTestPost.count()) def test_multi_insert_with_same_values(self): date = pendulum.utcnow()._datetime - result = OratorTestPost.insert([ - { - 'user_id': 1, 'name': 'Post', 'created_at': date, 'updated_at': date - }, { - 'user_id': 1, 'name': 'Post', 'created_at': date, 'updated_at': date - } - ]) + user1 = OratorTestUser.create(email="john@doe.com") + result = OratorTestPost.insert( + [ + { + "user_id": user1.id, + "name": "Post", + "created_at": date, + "updated_at": date, + }, + { + "user_id": user1.id, + "name": "Post", + "created_at": date, + "updated_at": date, + }, + ] + ) self.assertTrue(result) self.assertEqual(2, 
OratorTestPost.count()) def test_belongs_to_many_further_query(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - friend = OratorTestUser.create(id=2, email='jane@doe.com') - another_friend = OratorTestUser.create(id=3, email='another@doe.com') + user = OratorTestUser.create(id=1, email="john@doe.com") + friend = OratorTestUser.create(id=2, email="jane@doe.com") + another_friend = OratorTestUser.create(id=3, email="another@doe.com") user.friends().attach(friend) user.friends().attach(another_friend) - related_friend = OratorTestUser.with_('friends').find(1).friends().where('test_users.id', 3).first() + related_friend = ( + OratorTestUser.with_("friends") + .find(1) + .friends() + .where("test_users.id", 3) + .first() + ) self.assertEqual(3, related_friend.id) - self.assertEqual('another@doe.com', related_friend.email) - self.assertIn('pivot', related_friend.to_dict()) + self.assertEqual("another@doe.com", related_friend.email) + self.assertIn("pivot", related_friend.to_dict()) self.assertEqual(1, related_friend.pivot.user_id) self.assertEqual(3, related_friend.pivot.friend_id) - self.assertTrue(hasattr(related_friend.pivot, 'id')) + self.assertTrue(hasattr(related_friend.pivot, "is_close")) + + self.assertIsInstance(user.friends().with_pivot("is_close"), BelongsToMany) - self.assertIsInstance(user.friends().with_pivot('id'), BelongsToMany) + self.assertEqual(2, user.friends().get().count()) + user.friends().sync([friend.id]) + self.assertEqual(1, user.friends().get().count()) def test_belongs_to_morph_many_eagerload(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - user.photos().create(name='Avatar 1') - user.photos().create(name='Avatar 2') - post = user.posts().create(name='First Post') - post.photos().create(name='Hero 1') - post.photos().create(name='Hero 2') - - posts = OratorTestPost.with_('user', 'photos').get() + user = OratorTestUser.create(id=1, email="john@doe.com") + user.photos().create(name="Avatar 1") + 
user.photos().create(name="Avatar 2") + post = user.posts().create(name="First Post") + post.photos().create(name="Hero 1") + post.photos().create(name="Hero 2") + + posts = OratorTestPost.with_("user", "photos").get() self.assertIsInstance(posts[0].user, OratorTestUser) self.assertEqual(user.id, posts[0].user().first().id) self.assertIsInstance(posts[0].photos, Collection) - self.assertEqual(posts[0].photos().where('name', 'Hero 2').first().name, 'Hero 2') + self.assertEqual( + posts[0].photos().where("name", "Hero 2").first().name, "Hero 2" + ) + + def test_belongs_to_associate(self): + user = OratorTestUser.create(id=1, email="john@doe.com") + post = OratorTestPost(name="Test Post") + + post.user().associate(user) + post.save() + + self.assertEqual(1, post.user.id) + + def test_belongs_to_associate_new_instances(self): + user = OratorTestUser.create(email="john@doe.com") + post = user.posts().create(name="First Post") + comment1 = OratorTestComment.create(body="test1", post_id=post.id) + + self.assertEqual(comment1.parent, None) + + comment2 = OratorTestComment.create(body="test2", post_id=post.id) + comment2.parent().associate(comment1) + + self.assertEqual(comment2.parent.id, comment1.id) def test_has_many_eagerload(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - post1 = user.posts().create(name='First Post') - post2 = user.posts().create(name='Second Post') + user = OratorTestUser.create(id=1, email="john@doe.com") + post1 = user.posts().create(name="First Post") + post2 = user.posts().create(name="Second Post") - user = OratorTestUser.with_('posts').first() + user = OratorTestUser.with_("posts").first() self.assertIsInstance(user.posts, Collection) - self.assertEqual(user.posts().where('name', 'Second Post').first().id, post2.id) + self.assertEqual(user.posts().where("name", "Second Post").first().id, post2.id) def test_relationships_properties_accept_builder(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - post1 = 
user.posts().create(name='First Post') - post2 = user.posts().create(name='Second Post') + user = OratorTestUser.create(id=1, email="john@doe.com") + post1 = user.posts().create(name="First Post") + post2 = user.posts().create(name="Second Post") - user = OratorTestUser.with_('posts').first() - columns = ', '.join(self.connection().get_query_grammar().wrap_list(['id', 'name', 'user_id'])) + user = OratorTestUser.with_("posts").first() + columns = ", ".join( + self.connection().get_query_grammar().wrap_list(["id", "name", "user_id"]) + ) self.assertEqual( - 'SELECT %(columns)s FROM %(table)s WHERE %(table)s.%(user_id)s = %(marker)s ORDER BY %(name)s DESC' + "SELECT %(columns)s FROM %(table)s WHERE %(table)s.%(user_id)s = %(marker)s ORDER BY %(name)s DESC" % { - 'columns': columns, - 'marker': self.marker, - 'table': self.grammar().wrap('test_posts'), - 'user_id': self.grammar().wrap('user_id'), - 'name': self.grammar().wrap('name') + "columns": columns, + "marker": self.marker, + "table": self.grammar().wrap("test_posts"), + "user_id": self.grammar().wrap("user_id"), + "name": self.grammar().wrap("name"), }, - user.post().to_sql() + user.post().to_sql(), ) user = OratorTestUser.first() self.assertEqual( - 'SELECT %(columns)s FROM %(table)s WHERE %(table)s.%(user_id)s = %(marker)s ORDER BY %(name)s DESC' + "SELECT %(columns)s FROM %(table)s WHERE %(table)s.%(user_id)s = %(marker)s ORDER BY %(name)s DESC" % { - 'columns': columns, - 'marker': self.marker, - 'table': self.grammar().wrap('test_posts'), - 'user_id': self.grammar().wrap('user_id'), - 'name': self.grammar().wrap('name') + "columns": columns, + "marker": self.marker, + "table": self.grammar().wrap("test_posts"), + "user_id": self.grammar().wrap("user_id"), + "name": self.grammar().wrap("name"), }, - user.post().to_sql() + user.post().to_sql(), ) def test_morph_to_eagerload(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - user.photos().create(name='Avatar 1') - 
user.photos().create(name='Avatar 2') - post = user.posts().create(name='First Post') - post.photos().create(name='Hero 1') - post.photos().create(name='Hero 2') - - photo = OratorTestPhoto.with_('imageable').where('name', 'Hero 2').first() + user = OratorTestUser.create(id=1, email="john@doe.com") + user.photos().create(name="Avatar 1") + user.photos().create(name="Avatar 2") + post = user.posts().create(name="First Post") + post.photos().create(name="Hero 1") + post.photos().create(name="Hero 2") + + photo = OratorTestPhoto.with_("imageable").where("name", "Hero 2").first() self.assertIsInstance(photo.imageable, OratorTestPost) self.assertEqual(post.id, photo.imageable.id) - self.assertEqual(post.id, photo.imageable().where('name', 'First Post').first().id) + self.assertEqual( + post.id, photo.imageable().where("name", "First Post").first().id + ) def test_json_type(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - photo = user.photos().create(name='Avatar 1', metadata={'foo': 'bar'}) + user = OratorTestUser.create(id=1, email="john@doe.com") + photo = user.photos().create(name="Avatar 1", metadata={"foo": "bar"}) photo = OratorTestPhoto.find(photo.id) - self.assertEqual('bar', photo.metadata['foo']) + self.assertEqual("bar", photo.metadata["foo"]) def test_local_scopes(self): - yesterday = created_at=datetime.utcnow() - timedelta(days=1) - john = OratorTestUser.create(id=1, email='john@doe.com', created_at=yesterday, updated_at=yesterday) - jane = OratorTestUser.create(id=2, email='jane@doe.com') + yesterday = datetime.utcnow() - timedelta(days=1) + john = OratorTestUser.create( + id=1, email="john@doe.com", created_at=yesterday, updated_at=yesterday + ) + jane = OratorTestUser.create(id=2, email="jane@doe.com") result = OratorTestUser.older_than(minutes=30).get() self.assertEqual(1, len(result)) - self.assertEqual('john@doe.com', result.first().email) + self.assertEqual("john@doe.com", result.first().email) - result = 
OratorTestUser.where_not_null('id').older_than(minutes=30).get() + result = OratorTestUser.where_not_null("id").older_than(minutes=30).get() self.assertEqual(1, len(result)) - self.assertEqual('john@doe.com', result.first().email) + self.assertEqual("john@doe.com", result.first().email) def test_repr_relations(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - photo = user.photos().create(name='Avatar 1', metadata={'foo': 'bar'}) + user = OratorTestUser.create(id=1, email="john@doe.com") + photo = user.photos().create(name="Avatar 1", metadata={"foo": "bar"}) repr(OratorTestUser.first().photos) - repr(OratorTestUser.with_('photos').first().photos) + repr(OratorTestUser.with_("photos").first().photos) def test_reconnection(self): db = Model.get_connection_resolver() @@ -340,65 +477,83 @@ def test_reconnection(self): db.disconnect() def test_raw_query(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - photo = user.photos().create(name='Avatar 1', metadata={'foo': 'bar'}) + user = OratorTestUser.create(id=1, email="john@doe.com") + photo = user.photos().create(name="Avatar 1", metadata={"foo": "bar"}) - user = self.connection().table('test_users')\ - .where_raw('test_users.email = %s' % self.get_marker(), 'john@doe.com')\ + user = ( + self.connection() + .table("test_users") + .where_raw("test_users.email = %s" % self.get_marker(), "john@doe.com") .first() + ) - self.assertEqual(1, user['id']) + self.assertEqual(1, user["id"]) photos = self.connection().select( - 'SELECT * FROM test_photos WHERE imageable_id = %(marker)s and imageable_type = %(marker)s' + "SELECT * FROM test_photos WHERE imageable_id = %(marker)s and imageable_type = %(marker)s" % {"marker": self.get_marker()}, - [str(user['id']), 'test_users'] + [str(user["id"]), "test_users"], ) - self.assertEqual('Avatar 1', photos[0]['name']) + self.assertEqual("Avatar 1", photos[0]["name"]) def test_pivot(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - friend 
= OratorTestUser.create(id=2, email='jane@doe.com') - another_friend = OratorTestUser.create(id=3, email='another@doe.com') + user = OratorTestUser.create(id=1, email="john@doe.com") + friend = OratorTestUser.create(id=2, email="jane@doe.com") + another_friend = OratorTestUser.create(id=3, email="another@doe.com") user.friends().attach(friend) user.friends().attach(another_friend) - user.friends().update_existing_pivot(friend.id, {'is_close': True}) - self.assertTrue(user.friends().where('test_users.email', 'jane@doe.com').first().pivot.is_close) + user.friends().update_existing_pivot(friend.id, {"is_close": True}) + self.assertTrue( + user.friends() + .where("test_users.email", "jane@doe.com") + .first() + .pivot.is_close + ) def test_serialization(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - photo = user.photos().create(name='Avatar 1', metadata={'foo': 'bar'}) + user = OratorTestUser.create(id=1, email="john@doe.com") + photo = user.photos().create(name="Avatar 1", metadata={"foo": "bar"}) serialized_user = OratorTestUser.first().serialize() serialized_photo = OratorTestPhoto.first().serialize() - self.assertEqual(1, serialized_user['id']) - self.assertEqual('john@doe.com', serialized_user['email']) - self.assertEqual('Avatar 1', serialized_photo['name']) - self.assertEqual('bar', serialized_photo['metadata']['foo']) - self.assertEqual('Avatar 1', json.loads(OratorTestPhoto.first().to_json())['name']) + self.assertEqual(1, serialized_user["id"]) + self.assertEqual("john@doe.com", serialized_user["email"]) + self.assertEqual("Avatar 1", serialized_photo["name"]) + self.assertEqual("bar", serialized_photo["metadata"]["foo"]) + self.assertEqual( + "Avatar 1", json.loads(OratorTestPhoto.first().to_json())["name"] + ) def test_query_builder_results_attribute_retrieval(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - users = self.connection().table('test_users').get() + user = OratorTestUser.create(id=1, email="john@doe.com") 
+ users = self.connection().table("test_users").get() - self.assertEqual('john@doe.com', users[0].email) - self.assertEqual('john@doe.com', users[0]['email']) + self.assertEqual("john@doe.com", users[0].email) + self.assertEqual("john@doe.com", users[0]["email"]) self.assertEqual(1, users[0].id) - self.assertEqual(1, users[0]['id']) + self.assertEqual(1, users[0]["id"]) + + def test_query_builder_results_serialization(self): + OratorTestUser.create(id=1, email="john@doe.com") + users = self.connection().table("test_users").get() + + serialized = json.loads(users.to_json())[0] + self.assertEqual(1, serialized["id"]) + self.assertEqual("john@doe.com", serialized["email"]) def test_connection_switching(self): - OratorTestUser.create(id=1, email='john@doe.com') + OratorTestUser.create(id=1, email="john@doe.com") - self.assertIsNone(OratorTestUser.on('test').first()) + self.assertIsNone(OratorTestUser.on("test").first()) self.assertIsNotNone(OratorTestUser.first()) - OratorTestUser.on('test').insert(id=1, email='jane@doe.com') - user = OratorTestUser.on('test').first() + OratorTestUser.on("test").insert(id=1, email="jane@doe.com") + user = OratorTestUser.on("test").first() connection = user.get_connection() - post = user.posts().create(name='Test') + post = user.posts().create(name="Test") self.assertEqual(connection, post.get_connection()) def test_columns_listing(self): @@ -408,11 +563,101 @@ def test_columns_listing(self): .all() ) - self.assertEqual(['created_at', 'email', 'id', 'updated_at'], column_names) + self.assertEqual(["created_at", "email", "id", "updated_at"], column_names) def test_has_column(self): + self.assertTrue(self.schema().has_column(OratorTestUser().get_table(), "email")) + + def test_table_exists(self): + self.assertTrue(self.schema().has_table(OratorTestUser().get_table())) + + def test_transaction(self): + count = self.connection().table("test_users").count() + + with self.connection().transaction(): + OratorTestUser.create(id=1, 
email="jane@doe.com") + self.connection().rollback() + + self.assertEqual(count, self.connection().table("test_users").count()) + + def test_date(self): + user = OratorTestUser.create(id=1, email="john@doe.com") + photo1 = user.photos().create(name="Photo 1", taken_on=pendulum.date.today()) + photo2 = user.photos().create(name="Photo 2") + + self.assertIsInstance(OratorTestPhoto.find(photo1.id).taken_on, date) + self.assertIsNone(OratorTestPhoto.find(photo2.id).taken_on) + + def test_chunk_update_builder(self): + for i in range(20): + self.connection().table("test_users").insert( + id=i + 1, email="john{}@doe.com".format(i) + ) + + count = 0 + for users in ( + self.connection().table("test_users").where("id", "<", 50).chunk(10) + ): + for user in users: + count += 1 + + if count == 10: + self.connection().table("test_users").where("id", user.id).update( + id=60 + ) + + self.assertEqual(count, 20) + + def test_chunk_update_model(self): + for i in range(20): + OratorTestUser.create(id=i + 1, email="john{}@doe.com".format(i)) + + count = 0 + for users in OratorTestUser.where("id", "<", 50).chunk(10): + for user in users: + count += 1 + + if count == 10: + OratorTestUser.where("id", user.id).update(id=60) + + self.assertEqual(count, 20) + + def test_timestamp_with_timezone(self): + now = pendulum.utcnow() + user = OratorTestUser.create(email="john@doe.com", created_at=now) + fresh_user = OratorTestUser.find(user.id) + + self.assertEqual(user.created_at, fresh_user.created_at) + self.assertEqual(now, fresh_user.created_at) + + def test_touches(self): + user = OratorTestUser.create(email="john@doe.com") + post = user.posts().create(name="Post") + comment1 = post.comments().create(body="Comment 1") + comment2 = post.comments().create(body="Comment 2") + comment3 = post.comments().create(body="Comment 3") + comment4 = comment3.children().create(body="Comment 4", post_id=post.id) + + comment1_updated_at = comment1.updated_at + comment2_updated_at = comment2.updated_at + 
comment3_updated_at = comment3.updated_at + comment4_updated_at = comment4.updated_at + + comment4.body = "Comment 4 updated" + comment4.save() + + self.assertTrue(comment4.updated_at > comment4_updated_at) + self.assertEqual( + comment4.updated_at, OratorTestComment.find(comment4.id).updated_at + ) self.assertTrue( - self.schema().has_column(OratorTestUser().get_table(), 'email') + comment3_updated_at < OratorTestComment.find(comment3.id).updated_at + ) + self.assertEqual( + comment1_updated_at, OratorTestComment.find(comment1.id).updated_at + ) + self.assertEqual( + comment2_updated_at, OratorTestComment.find(comment2.id).updated_at ) def grammar(self): @@ -425,93 +670,146 @@ def schema(self, connection=None): return self.connection(connection).get_schema_builder() def migrate(self, connection=None): - self.schema(connection).drop_if_exists('test_users') - self.schema(connection).drop_if_exists('test_friends') - self.schema(connection).drop_if_exists('test_posts') - self.schema(connection).drop_if_exists('test_photos') - - with self.schema(connection).create('test_users') as table: - table.increments('id') - table.string('email').unique() + self.schema(connection).drop_if_exists("test_users") + self.schema(connection).drop_if_exists("test_friends") + self.schema(connection).drop_if_exists("test_posts") + self.schema(connection).drop_if_exists("test_photos") + + with self.schema(connection).create("test_users") as table: + table.increments("id") + table.string("email").unique() + table.timestamps(use_current=True) + + with self.schema(connection).create("test_friends") as table: + table.increments("id") + table.integer("user_id").unsigned() + table.integer("friend_id").unsigned() + table.boolean("is_close").default(False) + + table.foreign("user_id").references("id").on("test_users").on_delete( + "cascade" + ) + table.foreign("friend_id").references("id").on("test_users").on_delete( + "cascade" + ) + + with self.schema(connection).create("test_posts") as table: + 
table.increments("id") + table.integer("user_id").unsigned() + table.string("name") table.timestamps(use_current=True) - with self.schema(connection).create('test_friends') as table: - table.increments('id') - table.integer('user_id') - table.integer('friend_id') - table.boolean('is_close').default(False) + table.foreign("user_id").references("id").on("test_users").on_delete( + "cascade" + ) - with self.schema(connection).create('test_posts') as table: - table.increments('id') - table.integer('user_id') - table.string('name') + with self.schema(connection).create("test_comments") as table: + table.increments("id") + table.integer("post_id").unsigned() + table.integer("parent_id").unsigned().nullable() + table.text("body") table.timestamps(use_current=True) - with self.schema(connection).create('test_photos') as table: - table.increments('id') - table.morphs('imageable') - table.string('name') - table.json('metadata').nullable() + table.foreign("post_id").references("id").on("test_posts").on_delete( + "cascade" + ) + table.foreign("parent_id").references("id").on("test_comments").on_delete( + "cascade" + ) + + with self.schema(connection).create("test_photos") as table: + table.increments("id") + table.morphs("imageable") + table.string("name") + table.json("metadata").nullable() + table.date("taken_on").nullable() table.timestamps(use_current=True) def revert(self, connection=None): - self.schema(connection).drop_if_exists('test_users') - self.schema(connection).drop_if_exists('test_friends') - self.schema(connection).drop_if_exists('test_posts') - self.schema(connection).drop_if_exists('test_photos') + self.schema(connection).drop_if_exists("test_photos") + self.schema(connection).drop_if_exists("test_comments") + self.schema(connection).drop_if_exists("test_posts") + self.schema(connection).drop_if_exists("test_friends") + self.schema(connection).drop_if_exists("test_users") def get_marker(self): - return '?' + return "?" 
class OratorTestUser(Model): - __table__ = 'test_users' + __table__ = "test_users" __guarded__ = [] - @belongs_to_many('test_friends', 'user_id', 'friend_id', with_pivot=['id', 'is_close']) + @belongs_to_many("test_friends", "user_id", "friend_id", with_pivot=["is_close"]) def friends(self): return OratorTestUser - @has_many('user_id') + @has_many("user_id") def posts(self): - return 'test_posts' + return "test_posts" - @has_one('user_id') + @has_one("user_id") def post(self): - return OratorTestPost.select('id', 'name', 'name', 'user_id').order_by('name', 'desc') + return OratorTestPost.select("id", "name", "name", "user_id").order_by( + "name", "desc" + ) - @morph_many('imageable') + @morph_many("imageable") def photos(self): - return OratorTestPhoto.order_by('name') + return OratorTestPhoto.order_by("name") @scope def older_than(self, query, **kwargs): - query.where('updated_at', '<', (pendulum.utcnow() - timedelta(**kwargs))._datetime) + query.where("updated_at", "<", pendulum.utcnow().subtract(**kwargs)) class OratorTestPost(Model): - __table__ = 'test_posts' + __table__ = "test_posts" __guarded__ = [] - @belongs_to('user_id') + @belongs_to("user_id") def user(self): return OratorTestUser - @morph_many('imageable') + @has_many("post_id") + def comments(self): + return OratorTestComment + + @morph_many("imageable") def photos(self): - return OratorTestPhoto.order_by('name') + return OratorTestPhoto.order_by("name") + + +class OratorTestComment(Model): + + __touches__ = ["parent"] + + __table__ = "test_comments" + __guarded__ = [] + + @belongs_to("post_id") + def post(self): + return OratorTestPost + + @belongs_to("parent_id") + def parent(self): + return OratorTestComment + + @has_many("parent_id") + def children(self): + return OratorTestComment class OratorTestPhoto(Model): - __table__ = 'test_photos' + __table__ = "test_photos" __guarded__ = [] - __casts__ = { - 'metadata': 'json' - } + __casts__ = {"metadata": "json"} + + __dates__ = ["taken_on"] @morph_to 
def imageable(self): @@ -519,4 +817,4 @@ def imageable(self): @accessor def created_at(self): - return pendulum.instance(self._attributes['created_at']).to('Europe/Paris') + return pendulum.instance(self._attributes["created_at"]).in_tz("Europe/Paris") diff --git a/tests/integrations/test_mysql.py b/tests/integrations/test_mysql.py index 81bf8f5e..de945367 100644 --- a/tests/integrations/test_mysql.py +++ b/tests/integrations/test_mysql.py @@ -7,29 +7,28 @@ class MySQLIntegrationTestCase(IntegrationTestCase, OratorTestCase): - @classmethod def get_manager_config(cls): - ci = os.environ.get('CI', False) + ci = os.environ.get("CI", False) if ci: - database = 'orator_test' - user = 'root' - password = '' + database = "orator_test" + user = "root" + password = "" else: - database = 'orator_test' - user = 'orator' - password = 'orator' + database = "orator_test" + user = "orator" + password = "orator" return { - 'default': 'mysql', - 'mysql': { - 'driver': 'mysql', - 'database': database, - 'user': user, - 'password': password - } + "default": "mysql", + "mysql": { + "driver": "mysql", + "database": database, + "user": user, + "password": password, + }, } def get_marker(self): - return '%s' + return "%s" diff --git a/tests/integrations/test_mysql_qmark.py b/tests/integrations/test_mysql_qmark.py index 78553848..497a793b 100644 --- a/tests/integrations/test_mysql_qmark.py +++ b/tests/integrations/test_mysql_qmark.py @@ -7,30 +7,29 @@ class MySQLQmarkIntegrationTestCase(IntegrationTestCase, OratorTestCase): - @classmethod def get_manager_config(cls): - ci = os.environ.get('CI', False) + ci = os.environ.get("CI", False) if ci: - database = 'orator_test' - user = 'root' - password = '' + database = "orator_test" + user = "root" + password = "" else: - database = 'orator_test' - user = 'orator' - password = 'orator' + database = "orator_test" + user = "orator" + password = "orator" return { - 'default': 'mysql', - 'mysql': { - 'driver': 'mysql', - 'database': database, - 
'user': user, - 'password': password, - 'use_qmark': True - } + "default": "mysql", + "mysql": { + "driver": "mysql", + "database": database, + "user": user, + "password": password, + "use_qmark": True, + }, } def get_marker(self): - return '?' + return "?" diff --git a/tests/integrations/test_postgres.py b/tests/integrations/test_postgres.py index 03886acf..f4c1c886 100644 --- a/tests/integrations/test_postgres.py +++ b/tests/integrations/test_postgres.py @@ -7,29 +7,28 @@ class PostgresIntegrationTestCase(IntegrationTestCase, OratorTestCase): - @classmethod def get_manager_config(cls): - ci = os.environ.get('CI', False) + ci = os.environ.get("CI", False) if ci: - database = 'orator_test' - user = 'postgres' + database = "orator_test" + user = "postgres" password = None else: - database = 'orator_test' - user = 'orator' - password = 'orator' + database = "orator_test" + user = "orator" + password = "orator" return { - 'default': 'postgres', - 'postgres': { - 'driver': 'pgsql', - 'database': database, - 'user': user, - 'password': password - } + "default": "postgres", + "postgres": { + "driver": "pgsql", + "database": database, + "user": user, + "password": password, + }, } def get_marker(self): - return '%s' + return "%s" diff --git a/tests/integrations/test_postgres_qmark.py b/tests/integrations/test_postgres_qmark.py index 5f912322..25b038e2 100644 --- a/tests/integrations/test_postgres_qmark.py +++ b/tests/integrations/test_postgres_qmark.py @@ -7,30 +7,29 @@ class PostgresQmarkIntegrationTestCase(IntegrationTestCase, OratorTestCase): - @classmethod def get_manager_config(cls): - ci = os.environ.get('CI', False) + ci = os.environ.get("CI", False) if ci: - database = 'orator_test' - user = 'postgres' + database = "orator_test" + user = "postgres" password = None else: - database = 'orator_test' - user = 'orator' - password = 'orator' + database = "orator_test" + user = "orator" + password = "orator" return { - 'default': 'postgres', - 'postgres': { - 'driver': 
'pgsql', - 'database': database, - 'user': user, - 'password': password, - 'use_qmark': True - } + "default": "postgres", + "postgres": { + "driver": "pgsql", + "database": database, + "user": user, + "password": password, + "use_qmark": True, + }, } def get_marker(self): - return '?' + return "?" diff --git a/tests/integrations/test_sqlite.py b/tests/integrations/test_sqlite.py index a74615e6..116929d8 100644 --- a/tests/integrations/test_sqlite.py +++ b/tests/integrations/test_sqlite.py @@ -5,13 +5,9 @@ class SQLiteIntegrationTestCase(IntegrationTestCase, OratorTestCase): - @classmethod def get_manager_config(cls): return { - 'default': 'sqlite', - 'sqlite': { - 'driver': 'sqlite', - 'database': ':memory:' - } + "default": "sqlite", + "sqlite": {"driver": "sqlite", "database": ":memory:"}, } diff --git a/tests/migrations/__init__.py b/tests/migrations/__init__.py index 633f8661..40a96afc 100644 --- a/tests/migrations/__init__.py +++ b/tests/migrations/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/migrations/test_database_migration_repository.py b/tests/migrations/test_database_migration_repository.py index c7a6fa39..6c7faca6 100644 --- a/tests/migrations/test_database_migration_repository.py +++ b/tests/migrations/test_database_migration_repository.py @@ -10,7 +10,6 @@ class DatabaseMigrationRepositoryTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() @@ -18,50 +17,70 @@ def test_get_ran_migrations_list_migrations_by_package(self): repo = self.get_repository() connection = flexmock(Connection(None)) query = flexmock(QueryBuilder(connection, None, None)) - repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection) - repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query) - query.should_receive('lists').once().with_args('migration').and_return('bar') + repo.get_connection_resolver().should_receive("connection").with_args( + None + 
).and_return(connection) + repo.get_connection().should_receive("table").once().with_args( + "migrations" + ).and_return(query) + query.should_receive("lists").once().with_args("migration").and_return("bar") - self.assertEqual('bar', repo.get_ran()) + self.assertEqual("bar", repo.get_ran()) def test_get_last_migrations(self): resolver_mock = flexmock(DatabaseManager) - resolver_mock.should_receive('connection').and_return(None) + resolver_mock.should_receive("connection").and_return(None) resolver = flexmock(resolver_mock({})) - repo = flexmock(DatabaseMigrationRepository(resolver, 'migrations')) + repo = flexmock(DatabaseMigrationRepository(resolver, "migrations")) connection = flexmock(Connection(None)) query = flexmock(QueryBuilder(connection, None, None)) - repo.should_receive('get_last_batch_number').and_return(1) - repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection) - repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query) - query.should_receive('where').once().with_args('batch', 1).and_return(query) - query.should_receive('order_by').once().with_args('migration', 'desc').and_return(query) - query.should_receive('get').once().and_return('foo') - - self.assertEqual('foo', repo.get_last()) + repo.should_receive("get_last_batch_number").and_return(1) + repo.get_connection_resolver().should_receive("connection").with_args( + None + ).and_return(connection) + repo.get_connection().should_receive("table").once().with_args( + "migrations" + ).and_return(query) + query.should_receive("where").once().with_args("batch", 1).and_return(query) + query.should_receive("order_by").once().with_args( + "migration", "desc" + ).and_return(query) + query.should_receive("get").once().and_return("foo") + + self.assertEqual("foo", repo.get_last()) def test_log_inserts_record_into_migration_table(self): repo = self.get_repository() connection = flexmock(Connection(None)) query = 
flexmock(QueryBuilder(connection, None, None)) - repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection) - repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query) - query.should_receive('insert').once().with_args(migration='bar', batch=1) + repo.get_connection_resolver().should_receive("connection").with_args( + None + ).and_return(connection) + repo.get_connection().should_receive("table").once().with_args( + "migrations" + ).and_return(query) + query.should_receive("insert").once().with_args(migration="bar", batch=1) - repo.log('bar', 1) + repo.log("bar", 1) def test_delete_removes_migration_from_table(self): repo = self.get_repository() connection = flexmock(Connection(None)) query = flexmock(QueryBuilder(connection, None, None)) - repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection) - repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query) - query.should_receive('where').once().with_args('migration', 'foo').and_return(query) - query.should_receive('delete').once() + repo.get_connection_resolver().should_receive("connection").with_args( + None + ).and_return(connection) + repo.get_connection().should_receive("table").once().with_args( + "migrations" + ).and_return(query) + query.should_receive("where").once().with_args("migration", "foo").and_return( + query + ) + query.should_receive("delete").once() class Migration(object): - migration = 'foo' + migration = "foo" def __getitem__(self, item): return self.migration @@ -70,10 +89,10 @@ def __getitem__(self, item): def test_get_next_batch_number_returns_last_batch_number_plus_one(self): resolver_mock = flexmock(DatabaseManager) - resolver_mock.should_receive('connection').and_return(None) + resolver_mock.should_receive("connection").and_return(None) resolver = flexmock(resolver_mock({})) - repo = flexmock(DatabaseMigrationRepository(resolver, 
'migrations')) - repo.should_receive('get_last_batch_number').and_return(1) + repo = flexmock(DatabaseMigrationRepository(resolver, "migrations")) + repo.should_receive("get_last_batch_number").and_return(1) self.assertEqual(2, repo.get_next_batch_number()) @@ -81,13 +100,17 @@ def test_get_last_batch_number_returns_max_batch(self): repo = self.get_repository() connection = flexmock(Connection(None)) query = flexmock(QueryBuilder(connection, None, None)) - repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection) - repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query) - query.should_receive('max').and_return(1) + repo.get_connection_resolver().should_receive("connection").with_args( + None + ).and_return(connection) + repo.get_connection().should_receive("table").once().with_args( + "migrations" + ).and_return(query) + query.should_receive("max").and_return(1) self.assertEqual(1, repo.get_last_batch_number()) def get_repository(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) - return DatabaseMigrationRepository(flexmock(resolver({})), 'migrations') + resolver.should_receive("connection").and_return(None) + return DatabaseMigrationRepository(flexmock(resolver({})), "migrations") diff --git a/tests/migrations/test_migration_creator.py b/tests/migrations/test_migration_creator.py index 9386b841..c9eebdf3 100644 --- a/tests/migrations/test_migration_creator.py +++ b/tests/migrations/test_migration_creator.py @@ -9,55 +9,64 @@ class MigrationCreatorTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_basic_create_method_stores_migration_file(self): - expected = os.path.join(tempfile.gettempdir(), 'foo_create_bar.py') + expected = os.path.join(tempfile.gettempdir(), "foo_create_bar.py") if os.path.exists(expected): os.remove(expected) creator = self.get_creator() - 
creator.should_receive('_get_date_prefix').and_return('foo') - creator.create('create_bar', tempfile.gettempdir()) + creator.should_receive("_get_date_prefix").and_return("foo") + creator.create("create_bar", tempfile.gettempdir()) self.assertTrue(os.path.exists(expected)) with open(expected) as fh: content = fh.read() - self.assertEqual(content, BLANK_STUB.replace('DummyClass', 'CreateBar')) + self.assertEqual(content, BLANK_STUB.replace("DummyClass", "CreateBar")) os.remove(expected) def test_table_update_migration_stores_migration_file(self): - expected = os.path.join(tempfile.gettempdir(), 'foo_create_bar.py') + expected = os.path.join(tempfile.gettempdir(), "foo_create_bar.py") if os.path.exists(expected): os.remove(expected) creator = self.get_creator() - creator.should_receive('_get_date_prefix').and_return('foo') - creator.create('create_bar', tempfile.gettempdir(), 'baz') + creator.should_receive("_get_date_prefix").and_return("foo") + creator.create("create_bar", tempfile.gettempdir(), "baz") self.assertTrue(os.path.exists(expected)) with open(expected) as fh: content = fh.read() - self.assertEqual(content, UPDATE_STUB.replace('DummyClass', 'CreateBar').replace('dummy_table', 'baz')) + self.assertEqual( + content, + UPDATE_STUB.replace("DummyClass", "CreateBar").replace( + "dummy_table", "baz" + ), + ) os.remove(expected) def test_table_create_migration_stores_migration_file(self): - expected = os.path.join(tempfile.gettempdir(), 'foo_create_bar.py') + expected = os.path.join(tempfile.gettempdir(), "foo_create_bar.py") if os.path.exists(expected): os.remove(expected) creator = self.get_creator() - creator.should_receive('_get_date_prefix').and_return('foo') - creator.create('create_bar', tempfile.gettempdir(), 'baz', True) + creator.should_receive("_get_date_prefix").and_return("foo") + creator.create("create_bar", tempfile.gettempdir(), "baz", True) self.assertTrue(os.path.exists(expected)) with open(expected) as fh: content = fh.read() - 
self.assertEqual(content, CREATE_STUB.replace('DummyClass', 'CreateBar').replace('dummy_table', 'baz')) + self.assertEqual( + content, + CREATE_STUB.replace("DummyClass", "CreateBar").replace( + "dummy_table", "baz" + ), + ) os.remove(expected) diff --git a/tests/migrations/test_migrator.py b/tests/migrations/test_migrator.py index e7d7dfe0..c62c3d2d 100644 --- a/tests/migrations/test_migrator.py +++ b/tests/migrations/test_migrator.py @@ -12,7 +12,6 @@ class MigratorTestCase(OratorTestCase): - def setUp(self): if PY3K: self.orig = inspect.getargspec @@ -26,271 +25,278 @@ def tearDown(self): def test_migrations_are_run_up_when_outstanding_migrations_exist(self): resolver_mock = flexmock(DatabaseManager) - resolver_mock.should_receive('connection').and_return({}) + resolver_mock.should_receive("connection").and_return({}) resolver = flexmock(DatabaseManager({})) connection = flexmock() - connection.should_receive('transaction').twice().and_return(connection) - resolver.should_receive('connection').and_return(connection) + connection.should_receive("transaction").twice().and_return(connection) + resolver.should_receive("connection").and_return(connection) migrator = flexmock( Migrator( - flexmock( - DatabaseMigrationRepository( - resolver, - 'migrations' - ) - ), - resolver + flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver ) ) g = flexmock(glob) - g.should_receive('glob').with_args(os.path.join(os.getcwd(), '[0-9]*_*.py')).and_return([ - os.path.join(os.getcwd(), '2_bar.py'), - os.path.join(os.getcwd(), '1_foo.py'), - os.path.join(os.getcwd(), '3_baz.py') - ]) - - migrator.get_repository().should_receive('get_ran').once().and_return(['1_foo']) - migrator.get_repository().should_receive('get_next_batch_number').once().and_return(1) - migrator.get_repository().should_receive('log').once().with_args('2_bar', 1) - migrator.get_repository().should_receive('log').once().with_args('3_baz', 1) + g.should_receive("glob").with_args( + 
os.path.join(os.getcwd(), "[0-9]*_*.py") + ).and_return( + [ + os.path.join(os.getcwd(), "2_bar.py"), + os.path.join(os.getcwd(), "1_foo.py"), + os.path.join(os.getcwd(), "3_baz.py"), + ] + ) + + migrator.get_repository().should_receive("get_ran").once().and_return(["1_foo"]) + migrator.get_repository().should_receive( + "get_next_batch_number" + ).once().and_return(1) + migrator.get_repository().should_receive("log").once().with_args("2_bar", 1) + migrator.get_repository().should_receive("log").once().with_args("3_baz", 1) bar_mock = flexmock(MigrationStub()) bar_mock.set_connection(connection) - bar_mock.should_receive('up').once() + bar_mock.should_receive("up").once() baz_mock = flexmock(MigrationStub()) baz_mock.set_connection(connection) - baz_mock.should_receive('up').once() - migrator.should_receive('_resolve').with_args(os.getcwd(), '2_bar').once().and_return(bar_mock) - migrator.should_receive('_resolve').with_args(os.getcwd(), '3_baz').once().and_return(baz_mock) + baz_mock.should_receive("up").once() + migrator.should_receive("_resolve").with_args( + os.getcwd(), "2_bar" + ).once().and_return(bar_mock) + migrator.should_receive("_resolve").with_args( + os.getcwd(), "3_baz" + ).once().and_return(baz_mock) migrator.run(os.getcwd()) def test_migrations_are_run_up_directly_if_transactional_is_false(self): resolver_mock = flexmock(DatabaseManager) - resolver_mock.should_receive('connection').and_return({}) + resolver_mock.should_receive("connection").and_return({}) resolver = flexmock(DatabaseManager({})) connection = flexmock() - connection.should_receive('transaction').never() - resolver.should_receive('connection').and_return(connection) + connection.should_receive("transaction").never() + resolver.should_receive("connection").and_return(connection) migrator = flexmock( Migrator( - flexmock( - DatabaseMigrationRepository( - resolver, - 'migrations' - ) - ), - resolver + flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver ) ) g = 
flexmock(glob) - g.should_receive('glob').with_args(os.path.join(os.getcwd(), '[0-9]*_*.py')).and_return([ - os.path.join(os.getcwd(), '2_bar.py'), - os.path.join(os.getcwd(), '1_foo.py'), - os.path.join(os.getcwd(), '3_baz.py') - ]) - - migrator.get_repository().should_receive('get_ran').once().and_return(['1_foo']) - migrator.get_repository().should_receive('get_next_batch_number').once().and_return(1) - migrator.get_repository().should_receive('log').once().with_args('2_bar', 1) - migrator.get_repository().should_receive('log').once().with_args('3_baz', 1) + g.should_receive("glob").with_args( + os.path.join(os.getcwd(), "[0-9]*_*.py") + ).and_return( + [ + os.path.join(os.getcwd(), "2_bar.py"), + os.path.join(os.getcwd(), "1_foo.py"), + os.path.join(os.getcwd(), "3_baz.py"), + ] + ) + + migrator.get_repository().should_receive("get_ran").once().and_return(["1_foo"]) + migrator.get_repository().should_receive( + "get_next_batch_number" + ).once().and_return(1) + migrator.get_repository().should_receive("log").once().with_args("2_bar", 1) + migrator.get_repository().should_receive("log").once().with_args("3_baz", 1) bar_mock = flexmock(MigrationStub()) bar_mock.transactional = False bar_mock.set_connection(connection) - bar_mock.should_receive('up').once() + bar_mock.should_receive("up").once() baz_mock = flexmock(MigrationStub()) baz_mock.transactional = False baz_mock.set_connection(connection) - baz_mock.should_receive('up').once() - migrator.should_receive('_resolve').with_args(os.getcwd(), '2_bar').once().and_return(bar_mock) - migrator.should_receive('_resolve').with_args(os.getcwd(), '3_baz').once().and_return(baz_mock) + baz_mock.should_receive("up").once() + migrator.should_receive("_resolve").with_args( + os.getcwd(), "2_bar" + ).once().and_return(bar_mock) + migrator.should_receive("_resolve").with_args( + os.getcwd(), "3_baz" + ).once().and_return(baz_mock) migrator.run(os.getcwd()) def test_up_migration_can_be_pretended(self): resolver_mock = 
flexmock(DatabaseManager) - resolver_mock.should_receive('connection').and_return({}) + resolver_mock.should_receive("connection").and_return({}) resolver = flexmock(DatabaseManager({})) connection = flexmock(Connection(None)) - connection.should_receive('get_logged_queries').twice().and_return([]) - resolver.should_receive('connection').with_args(None).and_return(connection) + connection.should_receive("get_logged_queries").twice().and_return([]) + resolver.should_receive("connection").with_args(None).and_return(connection) migrator = flexmock( Migrator( - flexmock( - DatabaseMigrationRepository( - resolver, - 'migrations' - ) - ), - resolver + flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver ) ) g = flexmock(glob) - g.should_receive('glob').with_args(os.path.join(os.getcwd(), '[0-9]*_*.py')).and_return([ - os.path.join(os.getcwd(), '2_bar.py'), - os.path.join(os.getcwd(), '1_foo.py'), - os.path.join(os.getcwd(), '3_baz.py') - ]) - - migrator.get_repository().should_receive('get_ran').once().and_return(['1_foo']) - migrator.get_repository().should_receive('get_next_batch_number').once().and_return(1) + g.should_receive("glob").with_args( + os.path.join(os.getcwd(), "[0-9]*_*.py") + ).and_return( + [ + os.path.join(os.getcwd(), "2_bar.py"), + os.path.join(os.getcwd(), "1_foo.py"), + os.path.join(os.getcwd(), "3_baz.py"), + ] + ) + + migrator.get_repository().should_receive("get_ran").once().and_return(["1_foo"]) + migrator.get_repository().should_receive( + "get_next_batch_number" + ).once().and_return(1) bar_mock = flexmock(MigrationStub()) - bar_mock.should_receive('get_connection').once().and_return(connection) - bar_mock.should_receive('up').once() + bar_mock.should_receive("get_connection").once().and_return(connection) + bar_mock.should_receive("up").once() baz_mock = flexmock(MigrationStub()) - baz_mock.should_receive('get_connection').once().and_return(connection) - baz_mock.should_receive('up').once() - 
migrator.should_receive('_resolve').with_args(os.getcwd(), '2_bar').once().and_return(bar_mock) - migrator.should_receive('_resolve').with_args(os.getcwd(), '3_baz').once().and_return(baz_mock) + baz_mock.should_receive("get_connection").once().and_return(connection) + baz_mock.should_receive("up").once() + migrator.should_receive("_resolve").with_args( + os.getcwd(), "2_bar" + ).once().and_return(bar_mock) + migrator.should_receive("_resolve").with_args( + os.getcwd(), "3_baz" + ).once().and_return(baz_mock) migrator.run(os.getcwd(), True) def test_nothing_is_done_when_no_migrations_outstanding(self): resolver_mock = flexmock(DatabaseManager) - resolver_mock.should_receive('connection').and_return(None) + resolver_mock.should_receive("connection").and_return(None) resolver = flexmock(DatabaseManager({})) migrator = flexmock( Migrator( - flexmock( - DatabaseMigrationRepository( - resolver, - 'migrations' - ) - ), - resolver + flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver ) ) g = flexmock(glob) - g.should_receive('glob').with_args(os.path.join(os.getcwd(), '[0-9]*_*.py')).and_return([ - os.path.join(os.getcwd(), '1_foo.py') - ]) + g.should_receive("glob").with_args( + os.path.join(os.getcwd(), "[0-9]*_*.py") + ).and_return([os.path.join(os.getcwd(), "1_foo.py")]) - migrator.get_repository().should_receive('get_ran').once().and_return(['1_foo']) + migrator.get_repository().should_receive("get_ran").once().and_return(["1_foo"]) migrator.run(os.getcwd()) def test_last_batch_of_migrations_can_be_rolled_back(self): resolver_mock = flexmock(DatabaseManager) - resolver_mock.should_receive('connection').and_return({}) + resolver_mock.should_receive("connection").and_return({}) resolver = flexmock(DatabaseManager({})) connection = flexmock() - connection.should_receive('transaction').twice().and_return(connection) - resolver.should_receive('connection').and_return(connection) + connection.should_receive("transaction").twice().and_return(connection) 
+ resolver.should_receive("connection").and_return(connection) migrator = flexmock( Migrator( - flexmock( - DatabaseMigrationRepository( - resolver, - 'migrations' - ) - ), - resolver + flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver ) ) - foo_migration = MigrationStub('foo') - bar_migration = MigrationStub('bar') - migrator.get_repository().should_receive('get_last').once().and_return([ - foo_migration, - bar_migration - ]) + foo_migration = MigrationStub("foo") + bar_migration = MigrationStub("bar") + migrator.get_repository().should_receive("get_last").once().and_return( + [foo_migration, bar_migration] + ) bar_mock = flexmock(MigrationStub()) bar_mock.set_connection(connection) - bar_mock.should_receive('down').once() + bar_mock.should_receive("down").once() foo_mock = flexmock(MigrationStub()) foo_mock.set_connection(connection) - foo_mock.should_receive('down').once() - migrator.should_receive('_resolve').with_args(os.getcwd(), 'bar').once().and_return(bar_mock) - migrator.should_receive('_resolve').with_args(os.getcwd(), 'foo').once().and_return(foo_mock) - - migrator.get_repository().should_receive('delete').once().with_args(bar_migration) - migrator.get_repository().should_receive('delete').once().with_args(foo_migration) + foo_mock.should_receive("down").once() + migrator.should_receive("_resolve").with_args( + os.getcwd(), "bar" + ).once().and_return(bar_mock) + migrator.should_receive("_resolve").with_args( + os.getcwd(), "foo" + ).once().and_return(foo_mock) + + migrator.get_repository().should_receive("delete").once().with_args( + bar_migration + ) + migrator.get_repository().should_receive("delete").once().with_args( + foo_migration + ) migrator.rollback(os.getcwd()) - def test_last_batch_of_migrations_can_be_rolled_back_directly_if_transactional_is_false(self): + def test_last_batch_of_migrations_can_be_rolled_back_directly_if_transactional_is_false( + self + ): resolver_mock = flexmock(DatabaseManager) - 
resolver_mock.should_receive('connection').and_return({}) + resolver_mock.should_receive("connection").and_return({}) resolver = flexmock(DatabaseManager({})) connection = flexmock() - connection.should_receive('transaction').never() - resolver.should_receive('connection').and_return(connection) + connection.should_receive("transaction").never() + resolver.should_receive("connection").and_return(connection) migrator = flexmock( Migrator( - flexmock( - DatabaseMigrationRepository( - resolver, - 'migrations' - ) - ), - resolver + flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver ) ) - foo_migration = MigrationStub('foo') - bar_migration = MigrationStub('bar') - migrator.get_repository().should_receive('get_last').once().and_return([ - foo_migration, - bar_migration - ]) + foo_migration = MigrationStub("foo") + bar_migration = MigrationStub("bar") + migrator.get_repository().should_receive("get_last").once().and_return( + [foo_migration, bar_migration] + ) bar_mock = flexmock(MigrationStub()) bar_mock.transactional = False bar_mock.set_connection(connection) - bar_mock.should_receive('down').once() + bar_mock.should_receive("down").once() foo_mock = flexmock(MigrationStub()) foo_mock.transactional = False foo_mock.set_connection(connection) - foo_mock.should_receive('down').once() - migrator.should_receive('_resolve').with_args(os.getcwd(), 'bar').once().and_return(bar_mock) - migrator.should_receive('_resolve').with_args(os.getcwd(), 'foo').once().and_return(foo_mock) - - migrator.get_repository().should_receive('delete').once().with_args(bar_migration) - migrator.get_repository().should_receive('delete').once().with_args(foo_migration) + foo_mock.should_receive("down").once() + migrator.should_receive("_resolve").with_args( + os.getcwd(), "bar" + ).once().and_return(bar_mock) + migrator.should_receive("_resolve").with_args( + os.getcwd(), "foo" + ).once().and_return(foo_mock) + + 
migrator.get_repository().should_receive("delete").once().with_args( + bar_migration + ) + migrator.get_repository().should_receive("delete").once().with_args( + foo_migration + ) migrator.rollback(os.getcwd()) def test_rollback_migration_can_be_pretended(self): resolver_mock = flexmock(DatabaseManager) - resolver_mock.should_receive('connection').and_return({}) + resolver_mock.should_receive("connection").and_return({}) resolver = flexmock(DatabaseManager({})) connection = flexmock(Connection(None)) - connection.should_receive('get_logged_queries').twice().and_return([]) - resolver.should_receive('connection').with_args(None).and_return(connection) + connection.should_receive("get_logged_queries").twice().and_return([]) + resolver.should_receive("connection").with_args(None).and_return(connection) migrator = flexmock( Migrator( - flexmock( - DatabaseMigrationRepository( - resolver, - 'migrations' - ) - ), - resolver + flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver ) ) - foo_migration = flexmock(MigrationStub('foo')) - foo_migration.should_receive('get_connection').and_return(connection) - bar_migration = flexmock(MigrationStub('bar')) - bar_migration.should_receive('get_connection').and_return(connection) - migrator.get_repository().should_receive('get_last').once().and_return([ - foo_migration, - bar_migration - ]) + foo_migration = flexmock(MigrationStub("foo")) + foo_migration.should_receive("get_connection").and_return(connection) + bar_migration = flexmock(MigrationStub("bar")) + bar_migration.should_receive("get_connection").and_return(connection) + migrator.get_repository().should_receive("get_last").once().and_return( + [foo_migration, bar_migration] + ) - migrator.should_receive('_resolve').with_args(os.getcwd(), 'bar').once().and_return(bar_migration) - migrator.should_receive('_resolve').with_args(os.getcwd(), 'foo').once().and_return(foo_migration) + migrator.should_receive("_resolve").with_args( + os.getcwd(), "bar" + 
).once().and_return(bar_migration) + migrator.should_receive("_resolve").with_args( + os.getcwd(), "foo" + ).once().and_return(foo_migration) migrator.rollback(os.getcwd(), True) @@ -301,27 +307,20 @@ def test_rollback_migration_can_be_pretended(self): def test_nothing_is_rolled_back_when_nothing_in_repository(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(None) + resolver.should_receive("connection").and_return(None) migrator = flexmock( Migrator( - flexmock( - DatabaseMigrationRepository( - resolver, - 'migrations' - ) - ), - resolver + flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver ) ) - migrator.get_repository().should_receive('get_last').once().and_return([]) + migrator.get_repository().should_receive("get_last").once().and_return([]) migrator.rollback(os.getcwd()) class MigrationStub(Migration): - def __init__(self, migration=None): self.migration = migration self.upped = False diff --git a/tests/orm/__init__.py b/tests/orm/__init__.py index 633f8661..40a96afc 100644 --- a/tests/orm/__init__.py +++ b/tests/orm/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/orm/mixins/__init__.py b/tests/orm/mixins/__init__.py index 633f8661..40a96afc 100644 --- a/tests/orm/mixins/__init__.py +++ b/tests/orm/mixins/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/orm/mixins/test_soft_deletes.py b/tests/orm/mixins/test_soft_deletes.py index 768bead2..03c037a0 100644 --- a/tests/orm/mixins/test_soft_deletes.py +++ b/tests/orm/mixins/test_soft_deletes.py @@ -13,7 +13,6 @@ class SoftDeletesTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() @@ -23,9 +22,9 @@ def test_delete_sets_soft_deleted_column(self): builder = flexmock(Builder) query_builder = flexmock(QueryBuilder(None, None, None)) query = Builder(query_builder) - model.should_receive('new_query').and_return(query) - builder.should_receive('where').once().with_args('id', 1).and_return(query) 
- builder.should_receive('update').once().with_args({'deleted_at': t}) + model.should_receive("new_query").and_return(query) + builder.should_receive("where").once().with_args("id", 1).and_return(query) + builder.should_receive("update").once().with_args({"deleted_at": t}) model.delete() self.assertIsInstance(model.deleted_at, datetime.datetime) @@ -33,7 +32,7 @@ def test_delete_sets_soft_deleted_column(self): def test_restore(self): model = flexmock(SoftDeleteModelStub()) model.set_exists(True) - model.should_receive('save').once() + model.should_receive("save").once() model.restore() @@ -41,12 +40,11 @@ def test_restore(self): class SoftDeleteModelStub(SoftDeletes, Model): - def get_key(self): return 1 def get_key_name(self): - return 'id' + return "id" def from_datetime(self, value): return t diff --git a/tests/orm/mixins/test_soft_deletes_integration.py b/tests/orm/mixins/test_soft_deletes_integration.py index bbcf5636..6b806bba 100644 --- a/tests/orm/mixins/test_soft_deletes_integration.py +++ b/tests/orm/mixins/test_soft_deletes_integration.py @@ -9,12 +9,7 @@ class SoftDeletesIntegrationTestCase(OratorTestCase): - databases = { - 'test': { - 'driver': 'sqlite', - 'database': ':memory:' - } - } + databases = {"test": {"driver": "sqlite", "database": ":memory:"}} def setUp(self): self.db = DatabaseManager(self.databases) @@ -24,30 +19,30 @@ def setUp(self): self.create_schema() def create_schema(self): - with self.schema().create('users') as table: - table.increments('id') - table.string('email').unique() + with self.schema().create("users") as table: + table.increments("id") + table.string("email").unique() table.timestamps() table.soft_deletes() - with self.schema().create('posts') as table: - table.increments('id') - table.string('title') - table.integer('user_id') + with self.schema().create("posts") as table: + table.increments("id") + table.string("title") + table.integer("user_id") table.timestamps() table.soft_deletes() - with 
self.schema().create('comments') as table: - table.increments('id') - table.string('body') - table.integer('post_id') + with self.schema().create("comments") as table: + table.increments("id") + table.string("body") + table.integer("post_id") table.timestamps() table.soft_deletes() def tearDown(self): - self.schema().drop('users') - self.schema().drop('posts') - self.schema().drop('comments') + self.schema().drop("users") + self.schema().drop("posts") + self.schema().drop("comments") Model.unset_connection_resolver() @@ -79,7 +74,7 @@ def test_soft_deletes_are_not_retrieved_from_builder_helpers(self): self.assertEqual(1, count) query = SoftDeletesTestUser.query() - self.assertEqual(1, len(query.lists('email'))) + self.assertEqual(1, len(query.lists("email"))) Paginator.current_page_resolver(lambda: 1) query = SoftDeletesTestUser.query() @@ -89,15 +84,20 @@ def test_soft_deletes_are_not_retrieved_from_builder_helpers(self): query = SoftDeletesTestUser.query() self.assertEqual(1, len(query.simple_paginate(2).items)) - self.assertEqual(0, SoftDeletesTestUser.where('email', 'john@doe.com').increment('id')) - self.assertEqual(0, SoftDeletesTestUser.where('email', 'john@doe.com').decrement('id')) - + self.assertEqual( + 0, SoftDeletesTestUser.where("email", "john@doe.com").increment("id") + ) + self.assertEqual( + 0, SoftDeletesTestUser.where("email", "john@doe.com").decrement("id") + ) def test_with_trashed_returns_all_records(self): self.create_users() self.assertEqual(2, SoftDeletesTestUser.with_trashed().get().count()) - self.assertIsInstance(SoftDeletesTestUser.with_trashed().find(1), SoftDeletesTestUser) + self.assertIsInstance( + SoftDeletesTestUser.with_trashed().find(1), SoftDeletesTestUser + ) def test_delete_sets_deleted_column(self): self.create_users() @@ -142,63 +142,83 @@ def test_first_or_new_ignores_soft_deletes(self): self.create_users() john = SoftDeletesTestUser.first_or_new(id=1) - self.assertEqual('john@doe.com', john.email) + 
self.assertEqual("john@doe.com", john.email) def test_where_has_with_deleted_relationship(self): self.create_users() - jane = SoftDeletesTestUser.where('email', 'jane@doe.com').first() - post = jane.posts().create(title='First Title') + jane = SoftDeletesTestUser.where("email", "jane@doe.com").first() + post = jane.posts().create(title="First Title") - users = SoftDeletesTestUser.where('email', 'john@doe.com').has('posts').get() + users = SoftDeletesTestUser.where("email", "john@doe.com").has("posts").get() self.assertEqual(0, len(users)) - users = SoftDeletesTestUser.where('email', 'jane@doe.com').has('posts').get() + users = SoftDeletesTestUser.where("email", "jane@doe.com").has("posts").get() self.assertEqual(1, len(users)) - users = SoftDeletesTestUser.where('email', 'doesnt@exist.com').or_has('posts').get() + users = ( + SoftDeletesTestUser.where("email", "doesnt@exist.com").or_has("posts").get() + ) self.assertEqual(1, len(users)) - users = SoftDeletesTestUser.where_has('posts', lambda q: q.where('title', 'First Title')).get() + users = SoftDeletesTestUser.where_has( + "posts", lambda q: q.where("title", "First Title") + ).get() self.assertEqual(1, len(users)) - users = SoftDeletesTestUser.where_has('posts', lambda q: q.where('title', 'Another Title')).get() + users = SoftDeletesTestUser.where_has( + "posts", lambda q: q.where("title", "Another Title") + ).get() self.assertEqual(0, len(users)) - users = SoftDeletesTestUser.where('email', 'doesnt@exist.com')\ - .or_where_has('posts', lambda q: q.where('title', 'First Title'))\ + users = ( + SoftDeletesTestUser.where("email", "doesnt@exist.com") + .or_where_has("posts", lambda q: q.where("title", "First Title")) .get() + ) self.assertEqual(1, len(users)) # With post delete post.delete() - users = SoftDeletesTestUser.has('posts').get() + users = SoftDeletesTestUser.has("posts").get() self.assertEqual(0, len(users)) def test_where_has_with_nested_deleted_relationship(self): self.create_users() - jane = 
SoftDeletesTestUser.where('email', 'jane@doe.com').first() - post = jane.posts().create(title='First Title') - comment = post.comments().create(body='Comment Body') + jane = SoftDeletesTestUser.where("email", "jane@doe.com").first() + post = jane.posts().create(title="First Title") + comment = post.comments().create(body="Comment Body") comment.delete() - users = SoftDeletesTestUser.has('posts.comments').get() + users = SoftDeletesTestUser.has("posts.comments").get() self.assertEqual(0, len(users)) - users = SoftDeletesTestUser.doesnt_have('posts.comments').get() + users = SoftDeletesTestUser.doesnt_have("posts.comments").get() self.assertEqual(1, len(users)) def test_or_where_with_soft_deletes_constraint(self): self.create_users() - users = SoftDeletesTestUser.where('email', 'john@doe.com').or_where('email', 'jane@doe.com') + users = SoftDeletesTestUser.where("email", "john@doe.com").or_where( + "email", "jane@doe.com" + ) + self.assertEqual(1, len(users.get())) + self.assertEqual(["jane@doe.com"], users.order_by("id").lists("email")) + + def test_where_exists_on_soft_delete_model(self): + self.create_users() + + users = SoftDeletesTestUser.where_exists( + SoftDeletesTestUser.where("email", "jane@doe.com") + ) + self.assertEqual(1, len(users.get())) - self.assertEqual(['jane@doe.com'], users.order_by('id').lists('email')) + self.assertEqual(["jane@doe.com"], users.order_by("id").lists("email")) def create_users(self): - john = SoftDeletesTestUser.create(email='john@doe.com') - jane = SoftDeletesTestUser.create(email='jane@doe.com') + john = SoftDeletesTestUser.create(email="john@doe.com") + jane = SoftDeletesTestUser.create(email="jane@doe.com") john.delete() @@ -211,9 +231,9 @@ def schema(self): class SoftDeletesTestUser(SoftDeletes, Model): - __table__ = 'users' + __table__ = "users" - __dates__ = ['deleted_at'] + __dates__ = ["deleted_at"] __guarded__ = [] @@ -224,9 +244,9 @@ def posts(self): class SoftDeletesTestPost(SoftDeletes, Model): - __table__ = 'posts' 
+ __table__ = "posts" - __dates__ = ['deleted_at'] + __dates__ = ["deleted_at"] __guarded__ = [] @@ -237,8 +257,8 @@ def comments(self): class SoftDeletesTestComment(SoftDeletes, Model): - __table__ = 'comments' + __table__ = "comments" - __dates__ = ['deleted_at'] + __dates__ = ["deleted_at"] __guarded__ = [] diff --git a/tests/orm/models.py b/tests/orm/models.py index 668ce86e..75b600d5 100644 --- a/tests/orm/models.py +++ b/tests/orm/models.py @@ -5,4 +5,4 @@ class User(Model): - __fillable__ = ['name'] + __fillable__ = ["name"] diff --git a/tests/orm/relations/__init__.py b/tests/orm/relations/__init__.py index 633f8661..40a96afc 100644 --- a/tests/orm/relations/__init__.py +++ b/tests/orm/relations/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/orm/relations/test_belongs_to.py b/tests/orm/relations/test_belongs_to.py index bba99566..e6325046 100644 --- a/tests/orm/relations/test_belongs_to.py +++ b/tests/orm/relations/test_belongs_to.py @@ -15,37 +15,37 @@ class OrmBelongsToTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_update_retrieve_model_and_updates(self): relation = self._get_relation() mock = flexmock(Model()) - mock.should_receive('fill').once().with_args({'foo': 'bar'}).and_return(mock) - mock.should_receive('save').once().and_return(True) - relation.get_query().should_receive('first').once().and_return(mock) + mock.should_receive("fill").once().with_args({"foo": "bar"}).and_return(mock) + mock.should_receive("save").once().and_return(True) + relation.get_query().should_receive("first").once().and_return(mock) - self.assertTrue(relation.update({'foo': 'bar'})) + self.assertTrue(relation.update({"foo": "bar"})) def test_relation_is_properly_initialized(self): relation = self._get_relation() model = flexmock(Model()) - model.should_receive('set_relation').once().with_args('foo', None) - models = relation.init_relation([model], 'foo') + model.should_receive("set_relation").once().with_args("foo", None) 
+ models = relation.init_relation([model], "foo") self.assertEqual([model], models) def test_eager_constraints_are_properly_added(self): relation = self._get_relation() - relation.get_query().get_query().should_receive('where_in').once()\ - .with_args('relation.id', ['foreign.value', 'foreign.value.two']) + relation.get_query().get_query().should_receive("where_in").once().with_args( + "relation.id", ["foreign.value", "foreign.value.two"] + ) model1 = OrmBelongsToModelStub() model2 = OrmBelongsToModelStub() - model2.foreign_key = 'foreign.value' + model2.foreign_key = "foreign.value" model3 = AnotherOrmBelongsToModelStub() - model3.foreign_key = 'foreign.value.two' + model3.foreign_key = "foreign.value.two" models = [model1, model2, model3] relation.add_eager_constraints(models) @@ -54,63 +54,60 @@ def test_models_are_properly_matched_to_parents(self): relation = self._get_relation() result1 = flexmock() - result1.should_receive('get_attribute').with_args('id').and_return(1) + result1.should_receive("get_attribute").with_args("id").and_return(1) result2 = flexmock() - result2.should_receive('get_attribute').with_args('id').and_return(2) + result2.should_receive("get_attribute").with_args("id").and_return(2) model1 = OrmBelongsToModelStub() model1.foreign_key = 1 model2 = OrmBelongsToModelStub() model2.foreign_key = 2 - models = relation.match([model1, model2], Collection([result1, result2]), 'foo') + models = relation.match([model1, model2], Collection([result1, result2]), "foo") - self.assertEqual(1, models[0].foo.get_attribute('id')) - self.assertEqual(2, models[1].foo.get_attribute('id')) + self.assertEqual(1, models[0].foo.get_attribute("id")) + self.assertEqual(2, models[1].foo.get_attribute("id")) def test_associate_sets_foreign_key_on_model(self): parent = Model() - parent.foreign_key = 'foreign.value' - parent.get_attribute = mock.MagicMock(return_value='foreign.value') + parent.foreign_key = "foreign.value" + parent.get_attribute = 
mock.MagicMock(return_value="foreign.value") parent.set_attribute = mock.MagicMock() parent.set_relation = mock.MagicMock() relation = self._get_relation(parent) associate = flexmock(Model()) - associate.should_receive('get_attribute').once().with_args('id').and_return(1) + associate.should_receive("get_attribute").once().with_args("id").and_return(1) relation.associate(associate) - parent.get_attribute.assert_has_calls([ - mock.call('foreign_key'), - mock.call('foreign_key') - ]) - parent.set_attribute.assert_has_calls([ - mock.call('foreign_key', 1) - ]) - parent.set_relation.assert_called_once_with('relation', associate) + parent.get_attribute.assert_has_calls( + [mock.call("foreign_key"), mock.call("foreign_key")] + ) + parent.set_attribute.assert_has_calls([mock.call("foreign_key", 1)]) + parent.set_relation.assert_called_once_with("relation", associate) def _get_relation(self, parent=None): flexmock(Builder) query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = Builder(query) - builder.should_receive('where').with_args('relation.id', '=', 'foreign.value') + builder.should_receive("where").with_args("relation.id", "=", "foreign.value") related = flexmock(Model()) - related.should_receive('new_query').and_return(builder) - related.should_receive('get_key_name').and_return('id') - related.should_receive('get_table').and_return('relation') - builder.should_receive('get_model').and_return(related) + related.should_receive("new_query").and_return(builder) + related.should_receive("get_key_name").and_return("id") + related.should_receive("get_table").and_return("relation") + builder.should_receive("get_model").and_return(related) if parent is None: parent = OrmBelongsToModelStub() - parent.foreign_key = 'foreign.value' + parent.foreign_key = "foreign.value" - return BelongsTo(builder, parent, 'foreign_key', 'id', 'relation') + return BelongsTo(builder, parent, "foreign_key", "id", "relation") class OrmBelongsToModelStub(Model): - foreign_key = 
'foreign.value' + foreign_key = "foreign.value" class AnotherOrmBelongsToModelStub(Model): - foreign_key = 'foreign.value.two' + foreign_key = "foreign.value.two" diff --git a/tests/orm/relations/test_belongs_to_many.py b/tests/orm/relations/test_belongs_to_many.py index 0bb0dd01..ef3a8588 100644 --- a/tests/orm/relations/test_belongs_to_many.py +++ b/tests/orm/relations/test_belongs_to_many.py @@ -18,76 +18,99 @@ class OrmBelongsToTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_models_are_properly_hydrated(self): model1 = OrmBelongsToManyModelStub() - model1.fill(name='john', pivot_user_id=1, pivot_role_id=2) + model1.fill(name="john", pivot_user_id=1, pivot_role_id=2) model2 = OrmBelongsToManyModelStub() - model2.fill(name='jane', pivot_user_id=3, pivot_role_id=4) + model2.fill(name="jane", pivot_user_id=3, pivot_role_id=4) models = [model1, model2] - base_builder = flexmock(Builder(QueryBuilder(MockConnection().prepare_mock(), - QueryGrammar(), QueryProcessor()))) + base_builder = flexmock( + Builder( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) + ) relation = self._get_relation() - relation.get_parent().should_receive('get_connection_name').and_return('foo.connection') - relation.get_query().get_query().should_receive('add_select').once()\ - .with_args(*['roles.*', 'user_role.user_id AS pivot_user_id', 'user_role.role_id AS pivot_role_id'])\ - .and_return(relation.get_query()) - relation.get_query().should_receive('get_models').once().and_return(models) - relation.get_query().should_receive('eager_load_relations').once().with_args(models).and_return(models) - relation.get_related().should_receive('new_collection').replace_with(lambda l: Collection(l)) - relation.get_query().should_receive('get_query').once().and_return(base_builder) + relation.get_parent().should_receive("get_connection_name").and_return( + "foo.connection" + ) + 
relation.get_query().get_query().should_receive("add_select").once().with_args( + *[ + "roles.*", + "user_role.user_id AS pivot_user_id", + "user_role.role_id AS pivot_role_id", + ] + ).and_return(relation.get_query()) + relation.get_query().should_receive("get_models").once().and_return(models) + relation.get_query().should_receive("eager_load_relations").once().with_args( + models + ).and_return(models) + relation.get_related().should_receive("new_collection").replace_with( + lambda l: Collection(l) + ) + relation.get_query().should_receive("get_query").once().and_return(base_builder) results = relation.get() self.assertIsInstance(results, Collection) # Make sure the foreign keys were set on the pivot models - self.assertEqual('user_id', results[0].pivot.get_foreign_key()) - self.assertEqual('role_id', results[0].pivot.get_other_key()) + self.assertEqual("user_id", results[0].pivot.get_foreign_key()) + self.assertEqual("role_id", results[0].pivot.get_other_key()) - self.assertEqual('john', results[0].name) + self.assertEqual("john", results[0].name) self.assertEqual(1, results[0].pivot.user_id) self.assertEqual(2, results[0].pivot.role_id) - self.assertEqual('foo.connection', results[0].pivot.get_connection_name()) + self.assertEqual("foo.connection", results[0].pivot.get_connection_name()) - self.assertEqual('jane', results[1].name) + self.assertEqual("jane", results[1].name) self.assertEqual(3, results[1].pivot.user_id) self.assertEqual(4, results[1].pivot.role_id) - self.assertEqual('foo.connection', results[1].pivot.get_connection_name()) + self.assertEqual("foo.connection", results[1].pivot.get_connection_name()) - self.assertEqual('user_role', results[0].pivot.get_table()) + self.assertEqual("user_role", results[0].pivot.get_table()) self.assertTrue(results[0].pivot.exists) def test_timestamps_can_be_retrieved_properly(self): model1 = OrmBelongsToManyModelStub() - model1.fill(name='john', pivot_user_id=1, pivot_role_id=2) + model1.fill(name="john", 
pivot_user_id=1, pivot_role_id=2) model2 = OrmBelongsToManyModelStub() - model2.fill(name='jane', pivot_user_id=3, pivot_role_id=4) + model2.fill(name="jane", pivot_user_id=3, pivot_role_id=4) models = [model1, model2] - base_builder = flexmock(Builder(QueryBuilder(MockConnection().prepare_mock(), - QueryGrammar(), QueryProcessor()))) + base_builder = flexmock( + Builder( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) + ) relation = self._get_relation().with_timestamps() - relation.get_parent().should_receive('get_connection_name').and_return('foo.connection') - relation.get_query().get_query().should_receive('add_select').once()\ - .with_args( - 'roles.*', - 'user_role.user_id AS pivot_user_id', - 'user_role.role_id AS pivot_role_id', - 'user_role.created_at AS pivot_created_at', - 'user_role.updated_at AS pivot_updated_at' - )\ - .and_return(relation.get_query()) - relation.get_query().should_receive('get_models').once().and_return(models) - relation.get_query().should_receive('eager_load_relations').once().with_args(models).and_return(models) - relation.get_related().should_receive('new_collection').replace_with(lambda l: Collection(l)) - relation.get_query().should_receive('get_query').once().and_return(base_builder) + relation.get_parent().should_receive("get_connection_name").and_return( + "foo.connection" + ) + relation.get_query().get_query().should_receive("add_select").once().with_args( + "roles.*", + "user_role.user_id AS pivot_user_id", + "user_role.role_id AS pivot_role_id", + "user_role.created_at AS pivot_created_at", + "user_role.updated_at AS pivot_updated_at", + ).and_return(relation.get_query()) + relation.get_query().should_receive("get_models").once().and_return(models) + relation.get_query().should_receive("eager_load_relations").once().with_args( + models + ).and_return(models) + relation.get_related().should_receive("new_collection").replace_with( + lambda l: Collection(l) + ) + 
relation.get_query().should_receive("get_query").once().and_return(base_builder) results = relation.get() @@ -110,8 +133,12 @@ def test_models_are_properly_matched_to_parents(self): model3 = OrmBelongsToManyModelStub() model3.id = 3 - relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l)) - models = relation.match([model1, model2, model3], Collection([result1, result2, result3]), 'foo') + relation.get_related().should_receive("new_collection").replace_with( + lambda l=None: Collection(l) + ) + models = relation.match( + [model1, model2, model3], Collection([result1, result2, result3]), "foo" + ) self.assertEqual(1, models[0].foo[0].pivot.user_id) self.assertEqual(1, len(models[0].foo)) @@ -122,16 +149,20 @@ def test_models_are_properly_matched_to_parents(self): def test_relation_is_properly_initialized(self): relation = self._get_relation() - relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l or [])) + relation.get_related().should_receive("new_collection").replace_with( + lambda l=None: Collection(l or []) + ) model = flexmock(Model()) - model.should_receive('set_relation').once().with_args('foo', Collection) - models = relation.init_relation([model], 'foo') + model.should_receive("set_relation").once().with_args("foo", Collection) + models = relation.init_relation([model], "foo") self.assertEqual([model], models) def test_eager_constraints_are_properly_added(self): relation = self._get_relation() - relation.get_query().get_query().should_receive('where_in').once().with_args('user_role.user_id', [1, 2]) + relation.get_query().get_query().should_receive("where_in").once().with_args( + "user_role.user_id", [1, 2] + ) model1 = OrmBelongsToManyModelStub() model1.id = 1 model2 = OrmBelongsToManyModelStub() @@ -143,102 +174,106 @@ def test_attach_inserts_pivot_table_record(self): flexmock(BelongsToMany, touch_if_touching=lambda: True) relation = self._get_relation() query = 
flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) - query.should_receive('insert').once().with_args([{'user_id': 1, 'role_id': 2, 'foo': 'bar'}]).and_return(True) + query.should_receive("from_").once().with_args("user_role").and_return(query) + query.should_receive("insert").once().with_args( + [{"user_id": 1, "role_id": 2, "foo": "bar"}] + ).and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + relation.should_receive("touch_if_touching").once() - relation.attach(2, {'foo': 'bar'}) + relation.attach(2, {"foo": "bar"}) def test_attach_multiple_inserts_pivot_table_record(self): flexmock(BelongsToMany, touch_if_touching=lambda: True) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) - query.should_receive('insert').once().with_args( + query.should_receive("from_").once().with_args("user_role").and_return(query) + query.should_receive("insert").once().with_args( [ - {'user_id': 1, 'role_id': 2, 'foo': 'bar'}, - {'user_id': 1, 'role_id': 3, 'bar': 'baz', 'foo': 'bar'} + {"user_id": 1, "role_id": 2, "foo": "bar"}, + {"user_id": 1, "role_id": 3, "bar": "baz", "foo": "bar"}, ] ).and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + 
relation.should_receive("touch_if_touching").once() - relation.attach([2, {3: {'bar': 'baz'}}], {'foo': 'bar'}) + relation.attach([2, {3: {"bar": "baz"}}], {"foo": "bar"}) def test_attach_inserts_pivot_table_records_with_timestamps_when_ncessary(self): flexmock(BelongsToMany, touch_if_touching=lambda: True) relation = self._get_relation().with_timestamps() query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive("from_").once().with_args("user_role").and_return(query) now = pendulum.now() - query.should_receive('insert').once().with_args( + query.should_receive("insert").once().with_args( [ - {'user_id': 1, 'role_id': 2, 'foo': 'bar', 'created_at': now, 'updated_at': now} + { + "user_id": 1, + "role_id": 2, + "foo": "bar", + "created_at": now, + "updated_at": now, + } ] ).and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.get_parent().should_receive('fresh_timestamp').once().and_return(now) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + relation.get_parent().should_receive("fresh_timestamp").once().and_return(now) + relation.should_receive("touch_if_touching").once() - relation.attach(2, {'foo': 'bar'}) + relation.attach(2, {"foo": "bar"}) def test_attach_inserts_pivot_table_records_with_a_created_at_timestamp(self): flexmock(BelongsToMany, touch_if_touching=lambda: True) - relation = self._get_relation().with_pivot('created_at') + relation = self._get_relation().with_pivot("created_at") query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive("from_").once().with_args("user_role").and_return(query) now = 
pendulum.now() - query.should_receive('insert').once().with_args( - [ - {'user_id': 1, 'role_id': 2, 'foo': 'bar', 'created_at': now} - ] + query.should_receive("insert").once().with_args( + [{"user_id": 1, "role_id": 2, "foo": "bar", "created_at": now}] ).and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.get_parent().should_receive('fresh_timestamp').once().and_return(now) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + relation.get_parent().should_receive("fresh_timestamp").once().and_return(now) + relation.should_receive("touch_if_touching").once() - relation.attach(2, {'foo': 'bar'}) + relation.attach(2, {"foo": "bar"}) def test_attach_inserts_pivot_table_records_with_an_updated_at_timestamp(self): flexmock(BelongsToMany, touch_if_touching=lambda: True) - relation = self._get_relation().with_pivot('updated_at') + relation = self._get_relation().with_pivot("updated_at") query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive("from_").once().with_args("user_role").and_return(query) now = pendulum.now() - query.should_receive('insert').once().with_args( - [ - {'user_id': 1, 'role_id': 2, 'foo': 'bar', 'updated_at': now} - ] + query.should_receive("insert").once().with_args( + [{"user_id": 1, "role_id": 2, "foo": "bar", "updated_at": now}] ).and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.get_parent().should_receive('fresh_timestamp').once().and_return(now) - relation.should_receive('touch_if_touching').once() + 
relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + relation.get_parent().should_receive("fresh_timestamp").once().and_return(now) + relation.should_receive("touch_if_touching").once() - relation.attach(2, {'foo': 'bar'}) + relation.attach(2, {"foo": "bar"}) def test_detach_remove_pivot_table_record(self): flexmock(BelongsToMany, touch_if_touching=lambda: True) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) - query.should_receive('where').once().with_args('user_id', 1).and_return(query) - query.should_receive('where_in').once().with_args('role_id', [1, 2, 3]) - query.should_receive('delete').once().and_return(True) + query.should_receive("from_").once().with_args("user_role").and_return(query) + query.should_receive("where").once().with_args("user_id", 1).and_return(query) + query.should_receive("where_in").once().with_args("role_id", [1, 2, 3]) + query.should_receive("delete").once().and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + relation.should_receive("touch_if_touching").once() self.assertTrue(relation.detach([1, 2, 3])) @@ -246,14 +281,14 @@ def test_detach_with_single_id_remove_pivot_table_record(self): flexmock(BelongsToMany, touch_if_touching=lambda: True) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) - query.should_receive('where').once().with_args('user_id', 1).and_return(query) - 
query.should_receive('where_in').once().with_args('role_id', [1]) - query.should_receive('delete').once().and_return(True) + query.should_receive("from_").once().with_args("user_role").and_return(query) + query.should_receive("where").once().with_args("user_id", 1).and_return(query) + query.should_receive("where_in").once().with_args("role_id", [1]) + query.should_receive("delete").once().and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + relation.should_receive("touch_if_touching").once() self.assertTrue(relation.detach(1)) @@ -261,14 +296,14 @@ def test_detach_clears_all_records_when_no_ids(self): flexmock(BelongsToMany, touch_if_touching=lambda: True) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) - query.should_receive('where').once().with_args('user_id', 1).and_return(query) - query.should_receive('where_in').never() - query.should_receive('delete').once().and_return(True) + query.should_receive("from_").once().with_args("user_role").and_return(query) + query.should_receive("where").once().with_args("user_id", 1).and_return(query) + query.should_receive("where_in").never() + query.should_receive("delete").once().and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) 
+ relation.should_receive("touch_if_touching").once() self.assertTrue(relation.detach()) @@ -276,201 +311,240 @@ def test_create_creates_new_model_and_insert_attachment_record(self): flexmock(BelongsToMany, attach=lambda: True) relation = self._get_relation() model = flexmock() - relation.get_related().should_receive('new_instance').once().and_return(model).with_args({'foo': 'bar'}) - model.should_receive('save').once() - model.should_receive('get_key').and_return('foo') - relation.should_receive('attach').once().with_args('foo', {'bar': 'baz'}, True) + relation.get_related().should_receive("new_instance").once().and_return( + model + ).with_args({"foo": "bar"}) + model.should_receive("save").once() + model.should_receive("get_key").and_return("foo") + relation.should_receive("attach").once().with_args("foo", {"bar": "baz"}, True) - self.assertEqual(model, relation.create({'foo': 'bar'}, {'bar': 'baz'})) + self.assertEqual(model, relation.create({"foo": "bar"}, {"bar": "baz"})) def test_find_or_new_finds_model(self): flexmock(BelongsToMany) relation = self._get_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('find').with_args('foo', None).and_return(model) - relation.get_related().should_receive('new_instance').never() + model.foo = "bar" + relation.get_query().should_receive("find").with_args("foo", None).and_return( + model + ) + relation.get_related().should_receive("new_instance").never() - self.assertEqual('bar', relation.find_or_new('foo').foo) + self.assertEqual("bar", relation.find_or_new("foo").foo) def test_find_or_new_returns_new_model(self): flexmock(BelongsToMany) relation = self._get_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('find').with_args('foo', None).and_return(None) - relation.get_related().should_receive('new_instance').once().and_return(model) + model.foo = "bar" + relation.get_query().should_receive("find").with_args("foo", None).and_return( + None + ) + 
relation.get_related().should_receive("new_instance").once().and_return(model) - self.assertEqual('bar', relation.find_or_new('foo').foo) + self.assertEqual("bar", relation.find_or_new("foo").foo) def test_first_or_new_finds_first_model(self): flexmock(BelongsToMany) relation = self._get_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().and_return(model) - relation.get_related().should_receive('new_instance').never() + model.foo = "bar" + relation.get_query().should_receive("where").with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().and_return(model) + relation.get_related().should_receive("new_instance").never() - self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo) + self.assertEqual("bar", relation.first_or_new({"foo": "bar"}).foo) def test_first_or_new_returns_new_model(self): flexmock(BelongsToMany) relation = self._get_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().and_return(None) - relation.get_related().should_receive('new_instance').once().and_return(model) + model.foo = "bar" + relation.get_query().should_receive("where").with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().and_return(None) + relation.get_related().should_receive("new_instance").once().and_return(model) - self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo) + self.assertEqual("bar", relation.first_or_new({"foo": "bar"}).foo) def test_first_or_create_finds_first_model(self): flexmock(BelongsToMany) relation = self._get_relation() model = flexmock() - model.foo = 'bar' - 
relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().and_return(model) - relation.should_receive('create').never() + model.foo = "bar" + relation.get_query().should_receive("where").with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().and_return(model) + relation.should_receive("create").never() - self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo) + self.assertEqual("bar", relation.first_or_create({"foo": "bar"}).foo) def test_first_or_create_returns_new_model(self): flexmock(BelongsToMany) relation = self._get_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().and_return(None) - relation.should_receive('create').once().with_args({'foo': 'bar'}, {}, True).and_return(model) + model.foo = "bar" + relation.get_query().should_receive("where").with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().and_return(None) + relation.should_receive("create").once().with_args( + {"foo": "bar"}, {}, True + ).and_return(model) - self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo) + self.assertEqual("bar", relation.first_or_create({"foo": "bar"}).foo) def test_update_or_create_finds_first_mode_and_updates(self): flexmock(BelongsToMany) relation = self._get_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().and_return(model) - model.should_receive('fill').once() - model.should_receive('save').once() - relation.should_receive('create').never() + model.foo = "bar" + 
relation.get_query().should_receive("where").with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().and_return(model) + model.should_receive("fill").once() + model.should_receive("save").once() + relation.should_receive("create").never() - self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}).foo) + self.assertEqual("bar", relation.update_or_create({"foo": "bar"}).foo) def test_update_or_create_returns_new_model(self): flexmock(BelongsToMany) relation = self._get_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().and_return(None) - relation.should_receive('create').once().with_args({'bar': 'baz'}, None, True).and_return(model) + model.foo = "bar" + relation.get_query().should_receive("where").with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().and_return(None) + relation.should_receive("create").once().with_args( + {"bar": "baz"}, None, True + ).and_return(model) - self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'bar': 'baz'}).foo) + self.assertEqual( + "bar", relation.update_or_create({"foo": "bar"}, {"bar": "baz"}).foo + ) def test_sync_syncs_intermediate_table_with_given_list(self): - for list_ in [[2, 3, 4], ['2', '3', '4']]: + for list_ in [[2, 3, 4], ["2", "3", "4"]]: flexmock(BelongsToMany) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) - query.should_receive('where').once().with_args('user_id', 1).and_return(query) + query.should_receive("from_").once().with_args("user_role").and_return( + query + ) + query.should_receive("where").once().with_args("user_id", 1).and_return( + query + ) mock_query_builder = flexmock() - 
relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, list_[0], list_[1]])) - relation.should_receive('attach').once().with_args(list_[2], {}, False) - relation.should_receive('detach').once().with_args([1]) - relation.get_related().should_receive('touches').and_return(False) - relation.get_parent().should_receive('touches').and_return(False) + relation.get_query().should_receive("get_query").and_return( + mock_query_builder + ) + mock_query_builder.should_receive("new_query").once().and_return(query) + query.should_receive("lists").once().with_args("role_id").and_return( + Collection([1, list_[0], list_[1]]) + ) + relation.should_receive("attach").once().with_args(list_[2], {}, False) + relation.should_receive("detach").once().with_args([1]) + relation.get_related().should_receive("touches").and_return(False) + relation.get_parent().should_receive("touches").and_return(False) self.assertEqual( - { - 'attached': [list_[2]], - 'detached': [1], - 'updated': [] - }, - relation.sync(list_) + {"attached": [list_[2]], "detached": [1], "updated": []}, + relation.sync(list_), ) def test_sync_syncs_intermediate_table_with_given_list_and_attributes(self): flexmock(BelongsToMany) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) - query.should_receive('where').once().with_args('user_id', 1).and_return(query) + query.should_receive("from_").once().with_args("user_role").and_return(query) + query.should_receive("where").once().with_args("user_id", 1).and_return(query) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - 
query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, 2, 3])) - relation.should_receive('attach').once().with_args(4, {'foo': 'bar'}, False) - relation.should_receive('update_existing_pivot').once().with_args(3, {'bar': 'baz'}, False).and_return(True) - relation.should_receive('detach').once().with_args([1]) - relation.should_receive('touch_if_touching').once() - relation.get_related().should_receive('touches').and_return(False) - relation.get_parent().should_receive('touches').and_return(False) + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + query.should_receive("lists").once().with_args("role_id").and_return( + Collection([1, 2, 3]) + ) + relation.should_receive("attach").once().with_args(4, {"foo": "bar"}, False) + relation.should_receive("update_existing_pivot").once().with_args( + 3, {"bar": "baz"}, False + ).and_return(True) + relation.should_receive("detach").once().with_args([1]) + relation.should_receive("touch_if_touching").once() + relation.get_related().should_receive("touches").and_return(False) + relation.get_parent().should_receive("touches").and_return(False) self.assertEqual( - { - 'attached': [4], - 'detached': [1], - 'updated': [3] - }, - relation.sync([2, {3: {'bar': 'baz'}}, {4: {'foo': 'bar'}}], ) + {"attached": [4], "detached": [1], "updated": [3]}, + relation.sync([2, {3: {"bar": "baz"}}, {4: {"foo": "bar"}}]), ) def test_sync_does_not_return_values_that_were_not_updated(self): flexmock(BelongsToMany) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) - query.should_receive('where').once().with_args('user_id', 1).and_return(query) + query.should_receive("from_").once().with_args("user_role").and_return(query) + query.should_receive("where").once().with_args("user_id", 1).and_return(query) mock_query_builder = flexmock() - 
relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, 2, 3])) - relation.should_receive('attach').once().with_args(4, {'foo': 'bar'}, False) - relation.should_receive('update_existing_pivot').once().with_args(3, {'bar': 'baz'}, False).and_return(False) - relation.should_receive('detach').once().with_args([1]) - relation.should_receive('touch_if_touching').once() - relation.get_related().should_receive('touches').and_return(False) - relation.get_parent().should_receive('touches').and_return(False) + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + query.should_receive("lists").once().with_args("role_id").and_return( + Collection([1, 2, 3]) + ) + relation.should_receive("attach").once().with_args(4, {"foo": "bar"}, False) + relation.should_receive("update_existing_pivot").once().with_args( + 3, {"bar": "baz"}, False + ).and_return(False) + relation.should_receive("detach").once().with_args([1]) + relation.should_receive("touch_if_touching").once() + relation.get_related().should_receive("touches").and_return(False) + relation.get_parent().should_receive("touches").and_return(False) self.assertEqual( - { - 'attached': [4], - 'detached': [1], - 'updated': [] - }, - relation.sync([2, {3: {'bar': 'baz'}}, {4: {'foo': 'bar'}}], ) + {"attached": [4], "detached": [1], "updated": []}, + relation.sync([2, {3: {"bar": "baz"}}, {4: {"foo": "bar"}}]), ) def test_touch_method_syncs_timestamps(self): relation = self._get_relation() - relation.get_related().should_receive('get_updated_at_column').and_return('updated_at') + relation.get_related().should_receive("get_updated_at_column").and_return( + "updated_at" + ) now = pendulum.now() - 
relation.get_related().should_receive('fresh_timestamp').and_return(now) - relation.get_related().should_receive('get_qualified_key_name').and_return('table.id') - relation.get_query().get_query().should_receive('select').once().with_args('table.id')\ - .and_return(relation.get_query().get_query()) - relation.get_query().should_receive('lists').once().and_return(Collection([1, 2, 3])) + relation.get_related().should_receive("fresh_timestamp").and_return(now) + relation.get_related().should_receive("get_qualified_key_name").and_return( + "table.id" + ) + relation.get_query().get_query().should_receive("select").once().with_args( + "table.id" + ).and_return(relation.get_query().get_query()) + relation.get_query().should_receive("lists").once().and_return( + Collection([1, 2, 3]) + ) query = flexmock() - relation.get_related().should_receive('new_query').once().and_return(query) - query.should_receive('where_in').once().with_args('id', [1, 2, 3]).and_return(query) - query.should_receive('update').once().with_args({'updated_at': now}) + relation.get_related().should_receive("new_query").once().and_return(query) + query.should_receive("where_in").once().with_args("id", [1, 2, 3]).and_return( + query + ) + query.should_receive("update").once().with_args({"updated_at": now}) relation.touch() def test_touch_if_touching(self): flexmock(BelongsToMany) relation = self._get_relation() - relation.should_receive('_touching_parent').once().and_return(True) - relation.get_parent().should_receive('touch').once() - relation.get_parent().should_receive('touches').once().with_args('relation_name').and_return(True) - relation.should_receive('touch').once() + relation.should_receive("_touching_parent").once().and_return(True) + relation.get_parent().should_receive("touch").once() + relation.get_parent().should_receive("touches").once().with_args( + "relation_name" + ).and_return(True) + relation.should_receive("touch").once() relation.touch_if_touching() @@ -478,16 +552,20 @@ def 
test_sync_method_converts_collection_to_list_of_keys(self): flexmock(BelongsToMany) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('user_role').and_return(query) - query.should_receive('where').once().with_args('user_id', 1).and_return(query) + query.should_receive("from_").once().with_args("user_role").and_return(query) + query.should_receive("where").once().with_args("user_id", 1).and_return(query) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, 2, 3])) + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + query.should_receive("lists").once().with_args("role_id").and_return( + Collection([1, 2, 3]) + ) collection = flexmock(Collection()) - collection.should_receive('model_keys').once().and_return([1, 2, 3]) - relation.should_receive('_format_sync_list').with_args([1, 2, 3]).and_return({1: {}, 2: {}, 3: {}}) + collection.should_receive("model_keys").once().and_return([1, 2, 3]) + relation.should_receive("_format_sync_list").with_args([1, 2, 3]).and_return( + {1: {}, 2: {}, 3: {}} + ) relation.sync(collection) @@ -495,54 +573,70 @@ def test_where_pivot_params_used_for_new_queries(self): flexmock(BelongsToMany) relation = self._get_relation() - relation.get_query().should_receive('where').once().and_return(relation) + relation.get_query().should_receive("where").once().and_return(relation) query = flexmock() mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + 
mock_query_builder.should_receive("new_query").once().and_return(query) - query.should_receive('from_').once().with_args('user_role').and_return(query) + query.should_receive("from_").once().with_args("user_role").and_return(query) - query.should_receive('where').once().with_args('user_id', 1).and_return(query) + query.should_receive("where").once().with_args("user_id", 1).and_return(query) - query.should_receive('where').once().with_args('foo', '=', 'bar', 'and').and_return(query) + query.should_receive("where").once().with_args( + "foo", "=", "bar", "and" + ).and_return(query) - query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, 2, 3])) - relation.should_receive('_format_sync_list').with_args([1, 2, 3]).and_return({1: {}, 2: {}, 3: {}}) + query.should_receive("lists").once().with_args("role_id").and_return( + Collection([1, 2, 3]) + ) + relation.should_receive("_format_sync_list").with_args([1, 2, 3]).and_return( + {1: {}, 2: {}, 3: {}} + ) - relation = relation.where_pivot('foo', '=', 'bar') + relation = relation.where_pivot("foo", "=", "bar") relation.sync([1, 2, 3]) def _get_relation(self): builder, parent = self._get_relation_arguments()[:2] - return BelongsToMany(builder, parent, 'user_role', 'user_id', 'role_id', 'relation_name') + return BelongsToMany( + builder, parent, "user_role", "user_id", "role_id", "relation_name" + ) def _get_relation_arguments(self): - flexmock(Model).should_receive('_boot_columns').and_return(['name']) + flexmock(Model).should_receive("_boot_columns").and_return(["name"]) parent = flexmock(Model()) - parent.should_receive('get_key').and_return(1) - parent.should_receive('get_created_at_column').and_return('created_at') - parent.should_receive('get_updated_at_column').and_return('updated_at') + parent.should_receive("get_key").and_return(1) + parent.should_receive("get_created_at_column").and_return("created_at") + parent.should_receive("get_updated_at_column").and_return("updated_at") - query = 
flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) flexmock(Builder) builder = Builder(query) - builder.should_receive('get_query').and_return(query) + builder.should_receive("get_query").and_return(query) related = flexmock(Model()) builder.set_model(related) - builder.should_receive('get_model').and_return(related) + builder.should_receive("get_model").and_return(related) - related.should_receive('new_query').and_return(builder) - related.should_receive('get_key_name').and_return('id') - related.should_receive('get_table').and_return('roles') - related.should_receive('new_pivot').replace_with(lambda *args: Pivot(*args)) + related.should_receive("new_query").and_return(builder) + related.should_receive("get_key_name").and_return("id") + related.should_receive("get_table").and_return("roles") + related.should_receive("new_pivot").replace_with(lambda *args: Pivot(*args)) - builder.get_query().should_receive('join').at_least().once().with_args('user_role', 'roles.id', '=', 'user_role.role_id') - builder.should_receive('where').at_least().once().with_args('user_role.user_id', '=', 1) + builder.get_query().should_receive("join").at_least().once().with_args( + "user_role", "roles.id", "=", "user_role.role_id" + ) + builder.should_receive("where").at_least().once().with_args( + "user_role.user_id", "=", 1 + ) - return builder, parent, 'user_role', 'user_id', 'role_id', 'relation_id' + return builder, parent, "user_role", "user_id", "role_id", "relation_id" class OrmBelongsToManyModelStub(Model): diff --git a/tests/orm/relations/test_decorators.py b/tests/orm/relations/test_decorators.py index 20f62829..4e33c917 100644 --- a/tests/orm/relations/test_decorators.py +++ b/tests/orm/relations/test_decorators.py @@ -2,14 +2,20 @@ from ... 
import OratorTestCase from orator import Model as BaseModel -from orator.orm import morph_to, has_one, has_many, belongs_to_many, morph_many, belongs_to +from orator.orm import ( + morph_to, + has_one, + has_many, + belongs_to_many, + morph_many, + belongs_to, +) from orator.orm.model import ModelRegister from orator.connections import SQLiteConnection from orator.connectors.sqlite_connector import SQLiteConnector class DecoratorsTestCase(OratorTestCase): - @classmethod def setUpClass(cls): Model.set_connection_resolver(DatabaseIntegrationConnectionResolver()) @@ -19,41 +25,41 @@ def tearDownClass(cls): Model.unset_connection_resolver() def setUp(self): - with self.schema().create('test_users') as table: - table.increments('id') - table.string('email').unique() + with self.schema().create("test_users") as table: + table.increments("id") + table.string("email").unique() table.timestamps() - with self.schema().create('test_friends') as table: - table.increments('id') - table.integer('user_id') - table.integer('friend_id') + with self.schema().create("test_friends") as table: + table.increments("id") + table.integer("user_id") + table.integer("friend_id") - with self.schema().create('test_posts') as table: - table.increments('id') - table.integer('user_id') - table.string('name') + with self.schema().create("test_posts") as table: + table.increments("id") + table.integer("user_id") + table.string("name") table.timestamps() table.soft_deletes() - with self.schema().create('test_photos') as table: - table.increments('id') - table.morphs('imageable') - table.string('name') + with self.schema().create("test_photos") as table: + table.increments("id") + table.morphs("imageable") + table.string("name") table.timestamps() def tearDown(self): - self.schema().drop('test_users') - self.schema().drop('test_friends') - self.schema().drop('test_posts') - self.schema().drop('test_photos') + self.schema().drop("test_users") + self.schema().drop("test_friends") + 
self.schema().drop("test_posts") + self.schema().drop("test_photos") def test_extra_queries_are_properly_set_on_relations(self): self.create() # With eager loading - user = OratorTestUser.with_('friends', 'posts', 'post', 'photos').find(1) - post = OratorTestPost.with_('user', 'photos').find(1) + user = OratorTestUser.with_("friends", "posts", "post", "photos").find(1) + post = OratorTestPost.with_("user", "photos").find(1) self.assertEqual(1, len(user.friends)) self.assertEqual(2, len(user.posts)) self.assertIsInstance(user.post, OratorTestPost) @@ -63,27 +69,27 @@ def test_extra_queries_are_properly_set_on_relations(self): self.assertEqual( 'SELECT * FROM "test_users" INNER JOIN "test_friends" ON "test_users"."id" = "test_friends"."friend_id" ' 'WHERE "test_friends"."user_id" = ? ORDER BY "friend_id" ASC', - user.friends().to_sql() + user.friends().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_posts" WHERE "deleted_at" IS NULL AND "test_posts"."user_id" = ?', - user.posts().to_sql() + user.posts().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_posts" WHERE "test_posts"."user_id" = ? ORDER BY "name" DESC', - user.post().to_sql() + user.post().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_photos" WHERE "name" IS NOT NULL AND "test_photos"."imageable_id" = ? AND "test_photos"."imageable_type" = ?', - user.photos().to_sql() + user.photos().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_users" WHERE "test_users"."id" = ? ORDER BY "id" ASC', - post.user().to_sql() + post.user().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_photos" WHERE "test_photos"."imageable_id" = ? AND "test_photos"."imageable_type" = ?', - post.photos().to_sql() + post.photos().to_sql(), ) # Without eager loading @@ -98,45 +104,48 @@ def test_extra_queries_are_properly_set_on_relations(self): self.assertEqual( 'SELECT * FROM "test_users" INNER JOIN "test_friends" ON "test_users"."id" = "test_friends"."friend_id" ' 'WHERE "test_friends"."user_id" = ? 
ORDER BY "friend_id" ASC', - user.friends().to_sql() + user.friends().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_posts" WHERE "deleted_at" IS NULL AND "test_posts"."user_id" = ?', - user.posts().to_sql() + user.posts().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_posts" WHERE "test_posts"."user_id" = ? ORDER BY "name" DESC', - user.post().to_sql() + user.post().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_photos" WHERE "name" IS NOT NULL AND "test_photos"."imageable_id" = ? AND "test_photos"."imageable_type" = ?', - user.photos().to_sql() + user.photos().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_users" WHERE "test_users"."id" = ? ORDER BY "id" ASC', - post.user().to_sql() + post.user().to_sql(), ) self.assertEqual( 'SELECT * FROM "test_photos" WHERE "test_photos"."imageable_id" = ? AND "test_photos"."imageable_type" = ?', - post.photos().to_sql() + post.photos().to_sql(), ) + self.assertEqual( + 'SELECT DISTINCT * FROM "test_posts" WHERE "deleted_at" IS NULL AND "test_posts"."user_id" = ? 
ORDER BY "user_id" ASC', + user.posts().order_by("user_id").distinct().to_sql(), + ) def create(self): - user = OratorTestUser.create(id=1, email='john@doe.com') - friend = OratorTestUser.create(id=2, email='jane@doe.com') + user = OratorTestUser.create(id=1, email="john@doe.com") + friend = OratorTestUser.create(id=2, email="jane@doe.com") user.friends().attach(friend) - post1 = user.posts().create(name='First Post') - post2 = user.posts().create(name='Second Post') + post1 = user.posts().create(name="First Post") + post2 = user.posts().create(name="Second Post") - user.photos().create(name='Avatar 1') - user.photos().create(name='Avatar 2') - user.photos().create(name='Avatar 3') - - post1.photos().create(name='Hero 1') - post1.photos().create(name='Hero 2') + user.photos().create(name="Avatar 1") + user.photos().create(name="Avatar 2") + user.photos().create(name="Avatar 3") + post1.photos().create(name="Hero 1") + post1.photos().create(name="Hero 2") def connection(self): return Model.get_connection_resolver().connection() @@ -152,43 +161,43 @@ class Model(BaseModel): class OratorTestUser(Model): - __table__ = 'test_users' + __table__ = "test_users" __guarded__ = [] - @belongs_to_many('test_friends', 'user_id', 'friend_id', with_pivot=['id']) + @belongs_to_many("test_friends", "user_id", "friend_id", with_pivot=["id"]) def friends(self): - return OratorTestUser.order_by('friend_id') + return OratorTestUser.order_by("friend_id") - @has_many('user_id') + @has_many("user_id") def posts(self): - return OratorTestPost.where_null('deleted_at') + return OratorTestPost.where_null("deleted_at") - @has_one('user_id') + @has_one("user_id") def post(self): - return OratorTestPost.order_by('name', 'desc') + return OratorTestPost.order_by("name", "desc") - @morph_many('imageable') + @morph_many("imageable") def photos(self): - return OratorTestPhoto.where_not_null('name') + return OratorTestPhoto.where_not_null("name") class OratorTestPost(Model): - __table__ = 'test_posts' 
+ __table__ = "test_posts" __guarded__ = [] - @belongs_to('user_id') + @belongs_to("user_id") def user(self): - return OratorTestUser.order_by('id') + return OratorTestUser.order_by("id") - @morph_many('imageable') + @morph_many("imageable") def photos(self): - return 'test_photos' + return "test_photos" class OratorTestPhoto(Model): - __table__ = 'test_photos' + __table__ = "test_photos" __guarded__ = [] @morph_to @@ -204,12 +213,14 @@ def connection(self, name=None): if self._connection: return self._connection - self._connection = SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'})) + self._connection = SQLiteConnection( + SQLiteConnector().connect({"database": ":memory:"}) + ) return self._connection def get_default_connection(self): - return 'default' + return "default" def set_default_connection(self, name): pass diff --git a/tests/orm/relations/test_has_many.py b/tests/orm/relations/test_has_many.py index e25b3879..16a1e0ea 100644 --- a/tests/orm/relations/test_has_many.py +++ b/tests/orm/relations/test_has_many.py @@ -15,131 +15,168 @@ class OrmHasManyTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_create_properly_creates_new_model(self): relation = self._get_relation() created = flexmock(Model(), save=lambda: True, set_attribute=lambda: None) - created.should_receive('save').once().and_return(True) - relation.get_related().should_receive('new_instance').once().with_args({'name': 'john'}).and_return(created) - created.should_receive('set_attribute').with_args('foreign_key', 1) + created.should_receive("save").once().and_return(True) + relation.get_related().should_receive("new_instance").once().with_args( + {"name": "john"} + ).and_return(created) + created.should_receive("set_attribute").with_args("foreign_key", 1) - self.assertEqual(created, relation.create(name='john')) + self.assertEqual(created, relation.create(name="john")) def test_find_or_new_finds_model(self): relation = self._get_relation() model = 
flexmock() - model.foo = 'bar' - relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(model) - model.should_receive('set_attribute').never() + model.foo = "bar" + relation.get_query().should_receive("find").once().with_args( + "foo", ["*"] + ).and_return(model) + model.should_receive("set_attribute").never() - self.assertEqual('bar', relation.find_or_new('foo').foo) + self.assertEqual("bar", relation.find_or_new("foo").foo) def test_find_or_new_returns_new_model_with_foreign_key_set(self): relation = self._get_relation() - relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(None) + relation.get_query().should_receive("find").once().with_args( + "foo", ["*"] + ).and_return(None) model = flexmock() - model.foo = 'bar' - relation.get_related().should_receive('new_instance').once().with_args().and_return(model) - model.should_receive('set_attribute').once().with_args('foreign_key', 1) + model.foo = "bar" + relation.get_related().should_receive( + "new_instance" + ).once().with_args().and_return(model) + model.should_receive("set_attribute").once().with_args("foreign_key", 1) - self.assertEqual('bar', relation.find_or_new('foo').foo) + self.assertEqual("bar", relation.find_or_new("foo").foo) def test_first_or_new_finds_first_model(self): relation = self._get_relation() - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('first').once().with_args().and_return(model) - model.should_receive('set_attribute').never() + model.foo = "bar" + relation.get_query().should_receive("first").once().with_args().and_return( + model + ) + model.should_receive("set_attribute").never() - self.assertEqual('bar', relation.first_or_new(foo='bar').foo) + 
self.assertEqual("bar", relation.first_or_new(foo="bar").foo) def test_first_or_new_returns_new_model_with_foreign_key_set(self): relation = self._get_relation() - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().with_args().and_return(None) + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().with_args().and_return(None) model = flexmock() - model.foo = 'bar' - relation.get_related().should_receive('new_instance').once().with_args().and_return(model) - model.should_receive('set_attribute').once().with_args('foreign_key', 1) + model.foo = "bar" + relation.get_related().should_receive( + "new_instance" + ).once().with_args().and_return(model) + model.should_receive("set_attribute").once().with_args("foreign_key", 1) - self.assertEqual('bar', relation.first_or_new(foo='bar').foo) + self.assertEqual("bar", relation.first_or_new(foo="bar").foo) def test_first_or_create_finds_first_model(self): relation = self._get_relation() - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('first').once().with_args().and_return(model) - model.should_receive('set_attribute').never() + model.foo = "bar" + relation.get_query().should_receive("first").once().with_args().and_return( + model + ) + model.should_receive("set_attribute").never() - self.assertEqual('bar', relation.first_or_create(foo='bar').foo) + self.assertEqual("bar", relation.first_or_create(foo="bar").foo) def test_first_or_create_returns_new_model_with_foreign_key_set(self): relation = self._get_relation() - 
relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().with_args().and_return(None) + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().with_args().and_return(None) model = flexmock() - model.foo = 'bar' - relation.get_related().should_receive('new_instance').once().with_args({'foo': 'bar'}).and_return(model) - model.should_receive('save').once().and_return(True) - model.should_receive('set_attribute').once().with_args('foreign_key', 1) + model.foo = "bar" + relation.get_related().should_receive("new_instance").once().with_args( + {"foo": "bar"} + ).and_return(model) + model.should_receive("save").once().and_return(True) + model.should_receive("set_attribute").once().with_args("foreign_key", 1) - self.assertEqual('bar', relation.first_or_create(foo='bar').foo) + self.assertEqual("bar", relation.first_or_create(foo="bar").foo) def test_update_or_create_finds_first_model_and_updates(self): relation = self._get_relation() - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('first').once().with_args().and_return(model) - relation.get_related().should_receive('new_instance').never() - model.should_receive('fill').once().with_args({'foo': 'baz'}) - model.should_receive('save').once() - - self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'foo': 'baz'}).foo) + model.foo = "bar" + relation.get_query().should_receive("first").once().with_args().and_return( + model + ) + relation.get_related().should_receive("new_instance").never() + 
model.should_receive("fill").once().with_args({"foo": "baz"}) + model.should_receive("save").once() + + self.assertEqual( + "bar", relation.update_or_create({"foo": "bar"}, {"foo": "baz"}).foo + ) def test_update_or_create_creates_new_model_with_foreign_key_set(self): relation = self._get_relation() - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().with_args().and_return(None) + relation.get_query().should_receive("first").once().with_args().and_return(None) model = flexmock() - model.foo = 'bar' - relation.get_related().should_receive('new_instance').once().and_return(model) - model.should_receive('fill').once().with_args({'foo': 'baz'}) - model.should_receive('save').once() - model.should_receive('set_attribute').once().with_args('foreign_key', 1) + model.foo = "bar" + relation.get_related().should_receive("new_instance").once().and_return(model) + model.should_receive("fill").once().with_args({"foo": "baz"}) + model.should_receive("save").once() + model.should_receive("set_attribute").once().with_args("foreign_key", 1) - self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'foo': 'baz'}).foo) + self.assertEqual( + "bar", relation.update_or_create({"foo": "bar"}, {"foo": "baz"}).foo + ) def test_update_updates_models_with_timestamps(self): relation = self._get_relation() - relation.get_related().should_receive('uses_timestamps').once().and_return(True) + relation.get_related().should_receive("uses_timestamps").once().and_return(True) now = pendulum.now() - relation.get_related().should_receive('fresh_timestamp').once().and_return(now) - relation.get_query().should_receive('update').once().with_args({'foo': 'bar', 'updated_at': now}).and_return('results') + 
relation.get_related().should_receive("fresh_timestamp").once().and_return(now) + relation.get_query().should_receive("update").once().with_args( + {"foo": "bar", "updated_at": now} + ).and_return("results") - self.assertEqual('results', relation.update(foo='bar')) + self.assertEqual("results", relation.update(foo="bar")) def test_relation_is_properly_initialized(self): relation = self._get_relation() model = flexmock(Model()) - model.should_receive('set_relation').once().with_args('foo', Collection) - models = relation.init_relation([model], 'foo') + model.should_receive("set_relation").once().with_args("foo", Collection) + models = relation.init_relation([model], "foo") self.assertEqual([model], models) def test_eager_constraints_are_properly_added(self): relation = self._get_relation() - relation.get_query().get_query().should_receive('where_in').once().with_args('table.foreign_key', [1, 2]) + relation.get_query().get_query().should_receive("where_in").once().with_args( + "table.foreign_key", [1, 2] + ) model1 = OrmHasOneModelStub() model1.id = 1 @@ -148,6 +185,21 @@ def test_eager_constraints_are_properly_added(self): relation.add_eager_constraints([model1, model2]) + def test_save_many_returns_list_of_models(self): + relation = self._get_relation() + + model1 = flexmock() + model1.foo = "foo" + model1.should_receive("save").once().and_return(True) + model1.should_receive("set_attribute").once().with_args("foreign_key", 1) + + model2 = flexmock() + model2.foo = "bar" + model2.should_receive("save").once().and_return(True) + model2.should_receive("set_attribute").once().with_args("foreign_key", 1) + + self.assertEqual([model1, model2], relation.save_many([model1, model2])) + def test_models_are_properly_matched_to_parents(self): relation = self._get_relation() @@ -165,11 +217,19 @@ def test_models_are_properly_matched_to_parents(self): model3 = OrmHasOneModelStub() model3.id = 3 - relation.get_related().should_receive('new_collection').replace_with(lambda 
l=None: Collection(l)) - relation.get_query().should_receive('where').with_args('table.foreign_key', '=', 2) - relation.get_query().should_receive('where').with_args('table.foreign_key', '=', 3) + relation.get_related().should_receive("new_collection").replace_with( + lambda l=None: Collection(l) + ) + relation.get_query().should_receive("where").with_args( + "table.foreign_key", "=", 2 + ) + relation.get_query().should_receive("where").with_args( + "table.foreign_key", "=", 3 + ) - models = relation.match([model1, model2, model3], Collection([result1, result2, result3]), 'foo') + models = relation.match( + [model1, model2, model3], Collection([result1, result2, result3]), "foo" + ) self.assertEqual(1, models[0].foo[0].foreign_key) self.assertEqual(1, len(models[0].foo)) @@ -182,14 +242,16 @@ def test_relation_count_query_can_be_built(self): relation = self._get_relation() query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = Builder(query) - builder.get_query().should_receive('select').once() - relation.get_parent().should_receive('get_table').and_return('table') - builder.should_receive('where').once().with_args('table.foreign_key', '=', QueryExpression) + builder.get_query().should_receive("select").once() + relation.get_parent().should_receive("get_table").and_return("table") + builder.should_receive("where").once().with_args( + "table.foreign_key", "=", QueryExpression + ) parent_query = flexmock(QueryBuilder(None, None, None)) - relation.get_query().should_receive('get_query').and_return(parent_query) + relation.get_query().should_receive("get_query").and_return(parent_query) grammar = flexmock() - parent_query.should_receive('get_grammar').once().and_return(grammar) - grammar.should_receive('wrap').once().with_args('table.id') + parent_query.should_receive("get_grammar").once().and_return(grammar) + grammar.should_receive("wrap").once().with_args("table.id") relation.get_relation_count_query(builder, builder) @@ -197,17 +259,17 @@ def 
_get_relation(self): flexmock(Builder) query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = Builder(query) - builder.should_receive('where').with_args('table.foreign_key', '=', 1) + builder.should_receive("where").with_args("table.foreign_key", "=", 1) related = flexmock(Model()) - related.should_receive('new_query').and_return(builder) - builder.should_receive('get_model').and_return(related) + related.should_receive("new_query").and_return(builder) + builder.should_receive("get_model").and_return(related) parent = flexmock(Model()) - parent.should_receive('get_attribute').with_args('id').and_return(1) - parent.should_receive('get_created_at_column').and_return('created_at') - parent.should_receive('get_updated_at_column').and_return('updated_at') - parent.should_receive('new_query').and_return(builder) + parent.should_receive("get_attribute").with_args("id").and_return(1) + parent.should_receive("get_created_at_column").and_return("created_at") + parent.should_receive("get_updated_at_column").and_return("updated_at") + parent.should_receive("new_query").and_return(builder) - return HasMany(builder, parent, 'table.foreign_key', 'id') + return HasMany(builder, parent, "table.foreign_key", "id") class OrmHasOneModelStub(Model): diff --git a/tests/orm/relations/test_has_many_through.py b/tests/orm/relations/test_has_many_through.py index ed2110c9..6a1c8f6b 100644 --- a/tests/orm/relations/test_has_many_through.py +++ b/tests/orm/relations/test_has_many_through.py @@ -16,22 +16,25 @@ class OrmHasManyThroughTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_relation_is_properly_initialized(self): relation = self._get_relation() model = flexmock(Model()) - relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l or [])) - model.should_receive('set_relation').once().with_args('foo', Collection) - models = relation.init_relation([model], 'foo') + 
relation.get_related().should_receive("new_collection").replace_with( + lambda l=None: Collection(l or []) + ) + model.should_receive("set_relation").once().with_args("foo", Collection) + models = relation.init_relation([model], "foo") self.assertEqual([model], models) def test_eager_constraints_are_properly_added(self): relation = self._get_relation() - relation.get_query().get_query().should_receive('where_in').once().with_args('users.country_id', [1, 2]) + relation.get_query().get_query().should_receive("where_in").once().with_args( + "users.country_id", [1, 2] + ) model1 = OrmHasManyThroughModelStub() model1.id = 1 model2 = OrmHasManyThroughModelStub() @@ -55,10 +58,18 @@ def test_models_are_properly_matched_to_parents(self): model3 = OrmHasManyThroughModelStub() model3.id = 3 - relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l or [])) - relation.get_query().should_receive('where').with_args('users.country_id', '=', 2) - relation.get_query().should_receive('where').with_args('users.country_id', '=', 3) - models = relation.match([model1, model2, model3], Collection([result1, result2, result3]), 'foo') + relation.get_related().should_receive("new_collection").replace_with( + lambda l=None: Collection(l or []) + ) + relation.get_query().should_receive("where").with_args( + "users.country_id", "=", 2 + ) + relation.get_query().should_receive("where").with_args( + "users.country_id", "=", 3 + ) + models = relation.match( + [model1, model2, model3], Collection([result1, result2, result3]), "foo" + ) self.assertEqual(1, models[0].foo[0].country_id) self.assertEqual(1, len(models[0].foo)) @@ -71,8 +82,10 @@ def test_get(self): relation = self._get_relation() query = relation.get_query() - query.get_query().should_receive('add_select').once().with_args('posts.*', 'users.country_id').and_return(query) - query.should_receive('get_models').and_return([]) + query.get_query().should_receive("add_select").once().with_args( + 
"posts.*", "users.country_id" + ).and_return(query) + query.should_receive("get_models").and_return([]) relation.get() @@ -80,31 +93,33 @@ def _get_relation(self): flexmock(Builder) query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = Builder(query) - builder.get_query().should_receive('join').at_least().once().with_args('users', 'users.id', '=', 'posts.user_id') - builder.should_receive('where').with_args('users.country_id', '=', 1) + builder.get_query().should_receive("join").at_least().once().with_args( + "users", "users.id", "=", "posts.user_id" + ) + builder.should_receive("where").with_args("users.country_id", "=", 1) country = flexmock(Model()) - country.should_receive('get_key').and_return(1) - country.should_receive('get_foreign_key').and_return('country_id') + country.should_receive("get_key").and_return(1) + country.should_receive("get_foreign_key").and_return("country_id") user = flexmock(Model()) - user.should_receive('get_table').and_return('users') - user.should_receive('get_qualified_key_name').and_return('users.id') + user.should_receive("get_table").and_return("users") + user.should_receive("get_qualified_key_name").and_return("users.id") post = flexmock(Model()) - post.should_receive('get_table').and_return('posts') - builder.should_receive('get_model').and_return(post) + post.should_receive("get_table").and_return("posts") + builder.should_receive("get_model").and_return(post) - post.should_receive('new_query').and_return(builder) + post.should_receive("new_query").and_return(builder) - user.should_receive('get_key').and_return(1) - user.should_receive('get_created_at_column').and_return('created_at') - user.should_receive('get_updated_at_column').and_return('updated_at') + user.should_receive("get_key").and_return(1) + user.should_receive("get_created_at_column").and_return("created_at") + user.should_receive("get_updated_at_column").and_return("updated_at") parent = flexmock(Model()) - 
parent.should_receive('get_attribute').with_args('id').and_return(1) - parent.should_receive('get_created_at_column').and_return('created_at') - parent.should_receive('get_updated_at_column').and_return('updated_at') - parent.should_receive('new_query').and_return(builder) + parent.should_receive("get_attribute").with_args("id").and_return(1) + parent.should_receive("get_created_at_column").and_return("created_at") + parent.should_receive("get_updated_at_column").and_return("updated_at") + parent.should_receive("new_query").and_return(builder) - return HasManyThrough(builder, country, user, 'country_id', 'user_id') + return HasManyThrough(builder, country, user, "country_id", "user_id") class OrmHasManyThroughModelStub(Model): diff --git a/tests/orm/relations/test_has_one.py b/tests/orm/relations/test_has_one.py index 93e21ed8..20fb48a0 100644 --- a/tests/orm/relations/test_has_one.py +++ b/tests/orm/relations/test_has_one.py @@ -15,48 +15,53 @@ class OrmHasOneTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_save_method_set_foreign_key_on_model(self): relation = self._get_relation() mock_model = flexmock(Model(), save=lambda: True) - mock_model.should_receive('save').once().and_return(True) + mock_model.should_receive("save").once().and_return(True) result = relation.save(mock_model) attributes = result.get_attributes() - self.assertEqual(1, attributes['foreign_key']) + self.assertEqual(1, attributes["foreign_key"]) def test_create_properly_creates_new_model(self): relation = self._get_relation() created = flexmock(Model(), save=lambda: True, set_attribute=lambda: None) - created.should_receive('save').once().and_return(True) - relation.get_related().should_receive('new_instance').once().with_args({'name': 'john'}).and_return(created) - created.should_receive('set_attribute').with_args('foreign_key', 1) + created.should_receive("save").once().and_return(True) + relation.get_related().should_receive("new_instance").once().with_args( + 
{"name": "john"} + ).and_return(created) + created.should_receive("set_attribute").with_args("foreign_key", 1) - self.assertEqual(created, relation.create(name='john')) + self.assertEqual(created, relation.create(name="john")) def test_update_updates_models_with_timestamps(self): relation = self._get_relation() - relation.get_related().should_receive('uses_timestamps').once().and_return(True) + relation.get_related().should_receive("uses_timestamps").once().and_return(True) now = pendulum.now() - relation.get_related().should_receive('fresh_timestamp').once().and_return(now) - relation.get_query().should_receive('update').once().with_args({'foo': 'bar', 'updated_at': now}).and_return('results') + relation.get_related().should_receive("fresh_timestamp").once().and_return(now) + relation.get_query().should_receive("update").once().with_args( + {"foo": "bar", "updated_at": now} + ).and_return("results") - self.assertEqual('results', relation.update(foo='bar')) + self.assertEqual("results", relation.update(foo="bar")) def test_relation_is_properly_initialized(self): relation = self._get_relation() model = flexmock(Model()) - model.should_receive('set_relation').once().with_args('foo', None) - models = relation.init_relation([model], 'foo') + model.should_receive("set_relation").once().with_args("foo", None) + models = relation.init_relation([model], "foo") self.assertEqual([model], models) def test_eager_constraints_are_properly_added(self): relation = self._get_relation() - relation.get_query().get_query().should_receive('where_in').once().with_args('table.foreign_key', [1, 2]) + relation.get_query().get_query().should_receive("where_in").once().with_args( + "table.foreign_key", [1, 2] + ) model1 = OrmHasOneModelStub() model1.id = 1 @@ -80,10 +85,16 @@ def test_models_are_properly_matched_to_parents(self): model3 = OrmHasOneModelStub() model3.id = 3 - relation.get_query().should_receive('where').with_args('table.foreign_key', '=', 2) - 
relation.get_query().should_receive('where').with_args('table.foreign_key', '=', 3) + relation.get_query().should_receive("where").with_args( + "table.foreign_key", "=", 2 + ) + relation.get_query().should_receive("where").with_args( + "table.foreign_key", "=", 3 + ) - models = relation.match([model1, model2, model3], Collection([result1, result2]), 'foo') + models = relation.match( + [model1, model2, model3], Collection([result1, result2]), "foo" + ) self.assertEqual(1, models[0].foo.foreign_key) self.assertEqual(2, models[1].foo.foreign_key) @@ -93,14 +104,16 @@ def test_relation_count_query_can_be_built(self): relation = self._get_relation() query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = Builder(query) - builder.get_query().should_receive('select').once() - relation.get_parent().should_receive('get_table').and_return('table') - builder.should_receive('where').once().with_args('table.foreign_key', '=', QueryExpression) + builder.get_query().should_receive("select").once() + relation.get_parent().should_receive("get_table").and_return("table") + builder.should_receive("where").once().with_args( + "table.foreign_key", "=", QueryExpression + ) parent_query = flexmock(QueryBuilder(None, None, None)) - relation.get_query().should_receive('get_query').and_return(parent_query) + relation.get_query().should_receive("get_query").and_return(parent_query) grammar = flexmock() - parent_query.should_receive('get_grammar').once().and_return(grammar) - grammar.should_receive('wrap').once().with_args('table.id') + parent_query.should_receive("get_grammar").once().and_return(grammar) + grammar.should_receive("wrap").once().with_args("table.id") relation.get_relation_count_query(builder, builder) @@ -108,18 +121,18 @@ def _get_relation(self): flexmock(Builder) query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = Builder(query) - builder.should_receive('where').with_args('table.foreign_key', '=', 1) + 
builder.should_receive("where").with_args("table.foreign_key", "=", 1) related = flexmock(Model()) related_query = QueryBuilder(None, QueryGrammar(), None) - related.should_receive('new_query').and_return(Builder(related_query)) - builder.should_receive('get_model').and_return(related) + related.should_receive("new_query").and_return(Builder(related_query)) + builder.should_receive("get_model").and_return(related) parent = flexmock(Model()) - parent.should_receive('get_attribute').with_args('id').and_return(1) - parent.should_receive('get_created_at_column').and_return('created_at') - parent.should_receive('get_updated_at_column').and_return('updated_at') - parent.should_receive('new_query').and_return(builder) + parent.should_receive("get_attribute").with_args("id").and_return(1) + parent.should_receive("get_created_at_column").and_return("created_at") + parent.should_receive("get_updated_at_column").and_return("updated_at") + parent.should_receive("new_query").and_return(builder) - return HasOne(builder, parent, 'table.foreign_key', 'id') + return HasOne(builder, parent, "table.foreign_key", "id") class OrmHasOneModelStub(Model): diff --git a/tests/orm/relations/test_morph.py b/tests/orm/relations/test_morph.py index 714a2bb2..29b41c02 100644 --- a/tests/orm/relations/test_morph.py +++ b/tests/orm/relations/test_morph.py @@ -14,7 +14,6 @@ class OrmMorphTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() @@ -23,9 +22,12 @@ def test_morph_one_sets_proper_constraints(self): def test_morph_one_eager_constraints_are_properly_added(self): relation = self._get_one_relation() - relation.get_query().get_query().should_receive('where_in').once().with_args('table.morph_id', [1, 2]) - relation.get_query().should_receive('where').once()\ - .with_args('table.morph_type', relation.get_parent().__class__.__name__) + relation.get_query().get_query().should_receive("where_in").once().with_args( + "table.morph_id", [1, 2] + ) + 
relation.get_query().should_receive("where").once().with_args( + "table.morph_type", relation.get_parent().__class__.__name__ + ) model1 = Model() model1.id = 1 @@ -38,9 +40,12 @@ def test_morph_many_sets_proper_constraints(self): def test_morph_many_eager_constraints_are_properly_added(self): relation = self._get_many_relation() - relation.get_query().get_query().should_receive('where_in').once().with_args('table.morph_id', [1, 2]) - relation.get_query().should_receive('where').once()\ - .with_args('table.morph_type', relation.get_parent().__class__.__name__) + relation.get_query().get_query().should_receive("where_in").once().with_args( + "table.morph_id", [1, 2] + ) + relation.get_query().should_receive("where").once().with_args( + "table.morph_type", relation.get_parent().__class__.__name__ + ) model1 = Model() model1.id = 1 @@ -51,141 +56,190 @@ def test_morph_many_eager_constraints_are_properly_added(self): def test_create(self): relation = self._get_one_relation() created = flexmock(Model()) - created.should_receive('set_attribute').once().with_args('morph_id', 1) - created.should_receive('set_attribute').once()\ - .with_args('morph_type', relation.get_parent().__class__.__name__) - relation.get_related().should_receive('new_instance').once().with_args({'name': 'john'}).and_return(created) - created.should_receive('save').once().and_return(True) + created.should_receive("set_attribute").once().with_args("morph_id", 1) + created.should_receive("set_attribute").once().with_args( + "morph_type", relation.get_parent().__class__.__name__ + ) + relation.get_related().should_receive("new_instance").once().with_args( + {"name": "john"} + ).and_return(created) + created.should_receive("save").once().and_return(True) - self.assertEqual(created, relation.create(name='john')) + self.assertEqual(created, relation.create(name="john")) def test_find_or_new_finds_model(self): relation = self._get_one_relation() model = flexmock() - model.foo = 'bar' - 
relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(model) - relation.get_related().should_receive('new_instance').never() - model.should_receive('set_attribute').never() - model.should_receive('save').never() + model.foo = "bar" + relation.get_query().should_receive("find").once().with_args( + "foo", ["*"] + ).and_return(model) + relation.get_related().should_receive("new_instance").never() + model.should_receive("set_attribute").never() + model.should_receive("save").never() - self.assertEqual('bar', relation.find_or_new('foo').foo) + self.assertEqual("bar", relation.find_or_new("foo").foo) def test_find_or_new_returns_new_model_with_morph_keys_set(self): relation = self._get_one_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(None) - relation.get_related().should_receive('new_instance').once().with_args().and_return(model) - model.should_receive('set_attribute').once().with_args('morph_id', 1) - model.should_receive('set_attribute').once().with_args('morph_type', relation.get_parent().__class__.__name__) - model.should_receive('save').never() - - self.assertEqual('bar', relation.find_or_new('foo').foo) + model.foo = "bar" + relation.get_query().should_receive("find").once().with_args( + "foo", ["*"] + ).and_return(None) + relation.get_related().should_receive( + "new_instance" + ).once().with_args().and_return(model) + model.should_receive("set_attribute").once().with_args("morph_id", 1) + model.should_receive("set_attribute").once().with_args( + "morph_type", relation.get_parent().__class__.__name__ + ) + model.should_receive("save").never() + + self.assertEqual("bar", relation.find_or_new("foo").foo) def test_first_or_new_returns_first_model(self): relation = self._get_one_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) - 
relation.get_query().should_receive('first').once().with_args().and_return(model) - relation.get_related().should_receive('new_instance').never() - model.should_receive('set_attribute').never() - model.should_receive('set_attribute').never() - model.should_receive('save').never() - - self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo) + model.foo = "bar" + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().with_args().and_return( + model + ) + relation.get_related().should_receive("new_instance").never() + model.should_receive("set_attribute").never() + model.should_receive("set_attribute").never() + model.should_receive("save").never() + + self.assertEqual("bar", relation.first_or_new({"foo": "bar"}).foo) def test_first_or_new_returns_new_model_with_morph_keys_set(self): relation = self._get_one_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().with_args().and_return(None) - relation.get_related().should_receive('new_instance').once().with_args().and_return(model) - model.should_receive('set_attribute').once().with_args('morph_id', 1) - model.should_receive('set_attribute').once().with_args('morph_type', relation.get_parent().__class__.__name__) - model.should_receive('save').never() - - self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo) + model.foo = "bar" + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().with_args().and_return(None) + relation.get_related().should_receive( + "new_instance" + ).once().with_args().and_return(model) + model.should_receive("set_attribute").once().with_args("morph_id", 1) + 
model.should_receive("set_attribute").once().with_args( + "morph_type", relation.get_parent().__class__.__name__ + ) + model.should_receive("save").never() + + self.assertEqual("bar", relation.first_or_new({"foo": "bar"}).foo) def test_first_or_create_returns_first_model(self): relation = self._get_one_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().with_args().and_return(model) - relation.get_related().should_receive('new_instance').never() - model.should_receive('set_attribute').never() - model.should_receive('set_attribute').never() - model.should_receive('save').never() - - self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo) + model.foo = "bar" + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().with_args().and_return( + model + ) + relation.get_related().should_receive("new_instance").never() + model.should_receive("set_attribute").never() + model.should_receive("set_attribute").never() + model.should_receive("save").never() + + self.assertEqual("bar", relation.first_or_create({"foo": "bar"}).foo) def test_first_or_create_creates_new_morph_model(self): relation = self._get_one_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().with_args().and_return(None) - relation.get_related().should_receive('new_instance').once().with_args({'foo': 'bar'}).and_return(model) - model.should_receive('set_attribute').once().with_args('morph_id', 1) - model.should_receive('set_attribute').once().with_args('morph_type', relation.get_parent().__class__.__name__) - model.should_receive('save').once().and_return() - - 
self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo) + model.foo = "bar" + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().with_args().and_return(None) + relation.get_related().should_receive("new_instance").once().with_args( + {"foo": "bar"} + ).and_return(model) + model.should_receive("set_attribute").once().with_args("morph_id", 1) + model.should_receive("set_attribute").once().with_args( + "morph_type", relation.get_parent().__class__.__name__ + ) + model.should_receive("save").once().and_return() + + self.assertEqual("bar", relation.first_or_create({"foo": "bar"}).foo) def test_update_or_create_finds_first_model_and_updates(self): relation = self._get_one_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().with_args().and_return(model) - relation.get_related().should_receive('new_instance').never() - model.should_receive('set_attribute').never() - model.should_receive('set_attribute').never() - model.should_receive('fill').once().with_args({'bar': 'baz'}) - model.should_receive('save').once().and_return(True) - - self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'bar': 'baz'}).foo) + model.foo = "bar" + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().with_args().and_return( + model + ) + relation.get_related().should_receive("new_instance").never() + model.should_receive("set_attribute").never() + model.should_receive("set_attribute").never() + model.should_receive("fill").once().with_args({"bar": "baz"}) + model.should_receive("save").once().and_return(True) + + self.assertEqual( + "bar", relation.update_or_create({"foo": 
"bar"}, {"bar": "baz"}).foo + ) def test_update_or_create_finds_creates_new_morph_model(self): relation = self._get_one_relation() model = flexmock() - model.foo = 'bar' - relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query()) - relation.get_query().should_receive('first').once().with_args().and_return(None) - relation.get_related().should_receive('new_instance').once().with_args().and_return(model) - model.should_receive('set_attribute').once().with_args('morph_id', 1) - model.should_receive('set_attribute').once().with_args('morph_type', relation.get_parent().__class__.__name__) - model.should_receive('fill').once().with_args({'bar': 'baz'}) - model.should_receive('save').once().and_return(True) - - self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'bar': 'baz'}).foo) + model.foo = "bar" + relation.get_query().should_receive("where").once().with_args( + {"foo": "bar"} + ).and_return(relation.get_query()) + relation.get_query().should_receive("first").once().with_args().and_return(None) + relation.get_related().should_receive( + "new_instance" + ).once().with_args().and_return(model) + model.should_receive("set_attribute").once().with_args("morph_id", 1) + model.should_receive("set_attribute").once().with_args( + "morph_type", relation.get_parent().__class__.__name__ + ) + model.should_receive("fill").once().with_args({"bar": "baz"}) + model.should_receive("save").once().and_return(True) + + self.assertEqual( + "bar", relation.update_or_create({"foo": "bar"}, {"bar": "baz"}).foo + ) def _get_many_relation(self): flexmock(Builder) query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = Builder(query) - builder.should_receive('where').with_args('table.morph_id', '=', 1) + builder.should_receive("where").with_args("table.morph_id", "=", 1) related = flexmock(Model()) - builder.should_receive('get_model').and_return(related) + builder.should_receive("get_model").and_return(related) 
parent = flexmock(Model()) - parent.should_receive('get_attribute').with_args('id').and_return(1) - parent.should_receive('get_morph_name').and_return(parent.__class__.__name__) - builder.should_receive('where').once().with_args('table.morph_type', parent.__class__.__name__) + parent.should_receive("get_attribute").with_args("id").and_return(1) + parent.should_receive("get_morph_name").and_return(parent.__class__.__name__) + builder.should_receive("where").once().with_args( + "table.morph_type", parent.__class__.__name__ + ) - return MorphMany(builder, parent, 'table.morph_type', 'table.morph_id', 'id') + return MorphMany(builder, parent, "table.morph_type", "table.morph_id", "id") def _get_one_relation(self): flexmock(Builder) query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = Builder(query) - builder.should_receive('where').with_args('table.morph_id', '=', 1) + builder.should_receive("where").with_args("table.morph_id", "=", 1) related = flexmock(Model()) - builder.should_receive('get_model').and_return(related) + builder.should_receive("get_model").and_return(related) parent = flexmock(Model()) - parent.should_receive('get_attribute').with_args('id').and_return(1) - parent.should_receive('get_morph_name').and_return(parent.__class__.__name__) - builder.should_receive('where').once().with_args('table.morph_type', parent.__class__.__name__) + parent.should_receive("get_attribute").with_args("id").and_return(1) + parent.should_receive("get_morph_name").and_return(parent.__class__.__name__) + builder.should_receive("where").once().with_args( + "table.morph_type", parent.__class__.__name__ + ) - return MorphOne(builder, parent, 'table.morph_type', 'table.morph_id', 'id') + return MorphOne(builder, parent, "table.morph_type", "table.morph_id", "id") diff --git a/tests/orm/relations/test_morph_to.py b/tests/orm/relations/test_morph_to.py index d761b2ea..cbe12b32 100644 --- a/tests/orm/relations/test_morph_to.py +++ 
b/tests/orm/relations/test_morph_to.py @@ -14,7 +14,6 @@ class OrmMorphToTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() @@ -22,73 +21,79 @@ def test_lookup_dictionary_is_properly_constructed(self): relation = self._get_relation() one = flexmock() - one.morph_type = 'morph_type_1' - one.foreign_key = 'foreign_key_1' + one.morph_type = "morph_type_1" + one.foreign_key = "foreign_key_1" two = flexmock() - two.morph_type = 'morph_type_1' - two.foreign_key = 'foreign_key_1' + two.morph_type = "morph_type_1" + two.foreign_key = "foreign_key_1" three = flexmock() - three.morph_type = 'morph_type_2' - three.foreign_key = 'foreign_key_2' + three.morph_type = "morph_type_2" + three.foreign_key = "foreign_key_2" relation.add_eager_constraints([one, two, three]) dictionary = relation.get_dictionary() - self.assertEqual({ - 'morph_type_1': { - 'foreign_key_1': [ - one, - two - ] + self.assertEqual( + { + "morph_type_1": {"foreign_key_1": [one, two]}, + "morph_type_2": {"foreign_key_2": [three]}, }, - 'morph_type_2': { - 'foreign_key_2': [three] - } - }, dictionary) + dictionary, + ) def test_models_are_properly_pulled_and_matched(self): relation = self._get_relation() one = flexmock(Model()) - one.morph_type = 'morph_type_1' - one.foreign_key = 'foreign_key_1' + one.morph_type = "morph_type_1" + one.foreign_key = "foreign_key_1" two = flexmock(Model()) - two.morph_type = 'morph_type_1' - two.foreign_key = 'foreign_key_1' + two.morph_type = "morph_type_1" + two.foreign_key = "foreign_key_1" three = flexmock(Model()) - three.morph_type = 'morph_type_2' - three.foreign_key = 'foreign_key_2' + three.morph_type = "morph_type_2" + three.foreign_key = "foreign_key_2" relation.add_eager_constraints([one, two, three]) - first_query = flexmock(Builder(flexmock(QueryBuilder(None, QueryGrammar(), None)))) - second_query = flexmock(Builder(flexmock(QueryBuilder(None, QueryGrammar(), None)))) + first_query = flexmock( + Builder(flexmock(QueryBuilder(None, QueryGrammar(), 
None))) + ) + second_query = flexmock( + Builder(flexmock(QueryBuilder(None, QueryGrammar(), None))) + ) first_model = flexmock(Model()) second_model = flexmock(Model()) - relation.should_receive('_create_model_by_type').once().with_args('morph_type_1').and_return(first_model) - relation.should_receive('_create_model_by_type').once().with_args('morph_type_2').and_return(second_model) - first_model.should_receive('get_key_name').and_return('id') - second_model.should_receive('get_key_name').and_return('id') - - first_model.should_receive('new_query').once().and_return(first_query) - second_model.should_receive('new_query').once().and_return(second_query) - - first_query.get_query().should_receive('where_in').once()\ - .with_args('id', ['foreign_key_1']).and_return(first_query) + relation.should_receive("_create_model_by_type").once().with_args( + "morph_type_1" + ).and_return(first_model) + relation.should_receive("_create_model_by_type").once().with_args( + "morph_type_2" + ).and_return(second_model) + first_model.should_receive("get_key_name").and_return("id") + second_model.should_receive("get_key_name").and_return("id") + + first_model.should_receive("new_query").once().and_return(first_query) + second_model.should_receive("new_query").once().and_return(second_query) + + first_query.get_query().should_receive("where_in").once().with_args( + "id", ["foreign_key_1"] + ).and_return(first_query) result_one = flexmock() - first_query.should_receive('get').and_return(Collection.make([result_one])) - result_one.should_receive('get_key').and_return('foreign_key_1') + first_query.should_receive("get").and_return(Collection.make([result_one])) + result_one.should_receive("get_key").and_return("foreign_key_1") - second_query.get_query().should_receive('where_in').once()\ - .with_args('id', ['foreign_key_2']).and_return(second_query) + second_query.get_query().should_receive("where_in").once().with_args( + "id", ["foreign_key_2"] + ).and_return(second_query) result_two = 
flexmock() - second_query.should_receive('get').and_return(Collection.make([result_two])) - result_two.should_receive('get_key').and_return('foreign_key_2') + second_query.should_receive("get").and_return(Collection.make([result_two])) + result_two.should_receive("get_key").and_return("foreign_key_2") - one.should_receive('set_relation').once().with_args('relation', result_one) - two.should_receive('set_relation').once().with_args('relation', result_one) - three.should_receive('set_relation').once().with_args('relation', result_two) + one.should_receive("set_relation").once().with_args("relation", result_one) + two.should_receive("set_relation").once().with_args("relation", result_one) + three.should_receive("set_relation").once().with_args("relation", result_two) relation.get_eager() @@ -96,17 +101,19 @@ def test_models_are_properly_pulled_and_matched(self): def test_associate_sets_foreign_key_and_type_on_model(self): parent = flexmock(Model()) - parent.should_receive('get_attribute').once().with_args('foreign_key').and_return('foreign.value') + parent.should_receive("get_attribute").once().with_args( + "foreign_key" + ).and_return("foreign.value") relation = self._get_relation_associate(parent) associate = flexmock(Model()) - associate.should_receive('get_key').once().and_return(1) - associate.should_receive('get_morph_name').once().and_return('Model') + associate.should_receive("get_key").once().and_return(1) + associate.should_receive("get_morph_name").once().and_return("Model") - parent.should_receive('set_attribute').once().with_args('foreign_key', 1) - parent.should_receive('set_attribute').once().with_args('morph_type', 'Model') - parent.should_receive('set_relation').once().with_args('relation', associate) + parent.should_receive("set_attribute").once().with_args("foreign_key", 1) + parent.should_receive("set_attribute").once().with_args("morph_type", "Model") + parent.should_receive("set_relation").once().with_args("relation", associate) 
relation.associate(associate) @@ -114,30 +121,30 @@ def _get_relation_associate(self, parent): flexmock(Builder) query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = Builder(query) - builder.should_receive('where').with_args('relation.id', '=', 'foreign.value') + builder.should_receive("where").with_args("relation.id", "=", "foreign.value") related = flexmock(Model()) - related.should_receive('get_key').and_return(1) - related.should_receive('get_table').and_return('relation') - builder.should_receive('get_model').and_return(related) + related.should_receive("get_key").and_return(1) + related.should_receive("get_table").and_return("relation") + builder.should_receive("get_model").and_return(related) - return MorphTo(builder, parent, 'foreign_key', 'id', 'morph_type', 'relation') + return MorphTo(builder, parent, "foreign_key", "id", "morph_type", "relation") def _get_relation(self, parent=None, builder=None): flexmock(Builder) query = flexmock(QueryBuilder(None, QueryGrammar(), None)) builder = builder or Builder(query) - builder.should_receive('where').with_args('relation.id', '=', 'foreign.value') + builder.should_receive("where").with_args("relation.id", "=", "foreign.value") related = flexmock(Model()) - related.should_receive('get_key').and_return(1) - related.should_receive('get_table').and_return('relation') - builder.should_receive('get_model').and_return(related) + related.should_receive("get_key").and_return(1) + related.should_receive("get_table").and_return("relation") + builder.should_receive("get_model").and_return(related) parent = parent or OrmMorphToModelStub() flexmock(MorphTo) - return MorphTo(builder, parent, 'foreign_key', 'id', 'morph_type', 'relation') + return MorphTo(builder, parent, "foreign_key", "id", "morph_type", "relation") class OrmMorphToModelStub(Model): - foreign_key = 'foreign.value' + foreign_key = "foreign.value" diff --git a/tests/orm/relations/test_morph_to_many.py b/tests/orm/relations/test_morph_to_many.py 
index 62e54add..609ce528 100644 --- a/tests/orm/relations/test_morph_to_many.py +++ b/tests/orm/relations/test_morph_to_many.py @@ -18,15 +18,17 @@ class OrmMorphToManyTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_eager_constraints_are_properly_added(self): relation = self._get_relation() - relation.get_query().get_query().should_receive('where_in').once().with_args('taggables.taggable_id', [1, 2]) - relation.get_query().should_receive('where').once()\ - .with_args('taggables.taggable_type', relation.get_parent().__class__.__name__) + relation.get_query().get_query().should_receive("where_in").once().with_args( + "taggables.taggable_id", [1, 2] + ) + relation.get_query().should_receive("where").once().with_args( + "taggables.taggable_type", relation.get_parent().__class__.__name__ + ) model1 = OrmMorphToManyModelStub() model1.id = 1 model2 = OrmMorphToManyModelStub() @@ -38,37 +40,41 @@ def test_attach_inserts_pivot_table_record(self): flexmock(MorphToMany, touch_if_touching=lambda: True) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('taggables').and_return(query) - query.should_receive('insert').once()\ - .with_args( - [{ - 'taggable_id': 1, - 'taggable_type': relation.get_parent().__class__.__name__, - 'tag_id': 2, - 'foo': 'bar', - }])\ - .and_return(True) + query.should_receive("from_").once().with_args("taggables").and_return(query) + query.should_receive("insert").once().with_args( + [ + { + "taggable_id": 1, + "taggable_type": relation.get_parent().__class__.__name__, + "tag_id": 2, + "foo": "bar", + } + ] + ).and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + 
mock_query_builder.should_receive("new_query").once().and_return(query) + relation.should_receive("touch_if_touching").once() - relation.attach(2, {'foo': 'bar'}) + relation.attach(2, {"foo": "bar"}) def test_detach_remove_pivot_table_record(self): flexmock(MorphToMany, touch_if_touching=lambda: True) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('taggables').and_return(query) - query.should_receive('where').once().with_args('taggable_id', 1).and_return(query) - query.should_receive('where').once()\ - .with_args('taggable_type', relation.get_parent().__class__.__name__).and_return(query) - query.should_receive('where_in').once().with_args('tag_id', [1, 2, 3]) - query.should_receive('delete').once().and_return(True) + query.should_receive("from_").once().with_args("taggables").and_return(query) + query.should_receive("where").once().with_args("taggable_id", 1).and_return( + query + ) + query.should_receive("where").once().with_args( + "taggable_type", relation.get_parent().__class__.__name__ + ).and_return(query) + query.should_receive("where_in").once().with_args("tag_id", [1, 2, 3]) + query.should_receive("delete").once().and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + relation.should_receive("touch_if_touching").once() self.assertTrue(relation.detach([1, 2, 3])) @@ -76,48 +82,72 @@ def test_detach_clears_all_records_when_no_ids(self): flexmock(MorphToMany, touch_if_touching=lambda: True) relation = self._get_relation() query = flexmock() - query.should_receive('from_').once().with_args('taggables').and_return(query) - 
query.should_receive('where').once().with_args('taggable_id', 1).and_return(query) - query.should_receive('where').once()\ - .with_args('taggable_type', relation.get_parent().__class__.__name__).and_return(query) - query.should_receive('where_in').never() - query.should_receive('delete').once().and_return(True) + query.should_receive("from_").once().with_args("taggables").and_return(query) + query.should_receive("where").once().with_args("taggable_id", 1).and_return( + query + ) + query.should_receive("where").once().with_args( + "taggable_type", relation.get_parent().__class__.__name__ + ).and_return(query) + query.should_receive("where_in").never() + query.should_receive("delete").once().and_return(True) mock_query_builder = flexmock() - relation.get_query().should_receive('get_query').and_return(mock_query_builder) - mock_query_builder.should_receive('new_query').once().and_return(query) - relation.should_receive('touch_if_touching').once() + relation.get_query().should_receive("get_query").and_return(mock_query_builder) + mock_query_builder.should_receive("new_query").once().and_return(query) + relation.should_receive("touch_if_touching").once() self.assertTrue(relation.detach()) def _get_relation(self): builder, parent = self._get_relation_arguments()[:2] - return MorphToMany(builder, parent, 'taggable', 'taggables', 'taggable_id', 'tag_id') + return MorphToMany( + builder, parent, "taggable", "taggables", "taggable_id", "tag_id" + ) def _get_relation_arguments(self): parent = flexmock(Model()) - parent.should_receive('get_morph_name').and_return(parent.__class__.__name__) - parent.should_receive('get_key').and_return(1) - parent.should_receive('get_created_at_column').and_return('created_at') - parent.should_receive('get_updated_at_column').and_return('updated_at') - - query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + parent.should_receive("get_morph_name").and_return(parent.__class__.__name__) + 
parent.should_receive("get_key").and_return(1) + parent.should_receive("get_created_at_column").and_return("created_at") + parent.should_receive("get_updated_at_column").and_return("updated_at") + + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) flexmock(Builder) builder = Builder(query) - builder.should_receive('get_query').and_return(query) + builder.should_receive("get_query").and_return(query) related = flexmock(Model()) builder.set_model(related) - builder.should_receive('get_model').and_return(related) - - related.should_receive('get_key_name').and_return('id') - related.should_receive('get_table').and_return('tags') - related.should_receive('get_morph_name').and_return(parent.__class__.__name__) - - builder.get_query().should_receive('join').once().with_args('taggables', 'tags.id', '=', 'taggables.tag_id') - builder.should_receive('where').once().with_args('taggables.taggable_id', '=', 1) - builder.should_receive('where').once().with_args('taggables.taggable_type', parent.__class__.__name__) - - return builder, parent, 'taggable', 'taggables', 'taggable_id', 'tag_id', 'relation_name', False + builder.should_receive("get_model").and_return(related) + + related.should_receive("get_key_name").and_return("id") + related.should_receive("get_table").and_return("tags") + related.should_receive("get_morph_name").and_return(parent.__class__.__name__) + + builder.get_query().should_receive("join").once().with_args( + "taggables", "tags.id", "=", "taggables.tag_id" + ) + builder.should_receive("where").once().with_args( + "taggables.taggable_id", "=", 1 + ) + builder.should_receive("where").once().with_args( + "taggables.taggable_type", parent.__class__.__name__ + ) + + return ( + builder, + parent, + "taggable", + "taggables", + "taggable_id", + "tag_id", + "relation_name", + False, + ) class OrmMorphToManyModelStub(Model): diff --git a/tests/orm/relations/test_relation.py 
b/tests/orm/relations/test_relation.py index 2778cf12..7188eddf 100644 --- a/tests/orm/relations/test_relation.py +++ b/tests/orm/relations/test_relation.py @@ -12,37 +12,37 @@ class OrmRelationTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_set_relation_fail(self): parent = OrmRelationResetModelStub() relation = OrmRelationResetModelStub() - parent.set_relation('test', relation) - parent.set_relation('foo', 'bar') - self.assertFalse('foo' in parent.to_dict()) + parent.set_relation("test", relation) + parent.set_relation("foo", "bar") + self.assertFalse("foo" in parent.to_dict()) def test_touch_method_updates_related_timestamps(self): builder = flexmock(Builder, get_model=None, where=None) parent = Model() parent = flexmock(parent) - parent.should_receive('get_attribute').with_args('id').and_return(1) + parent.should_receive("get_attribute").with_args("id").and_return(1) related = Model() related = flexmock(related) - builder.should_receive('get_model').and_return(related) - builder.should_receive('where') - relation = HasOne(Builder(QueryBuilder(None, None, None)), parent, 'foreign_key', 'id') - related.should_receive('get_table').and_return('table') - related.should_receive('get_updated_at_column').and_return('updated_at') + builder.should_receive("get_model").and_return(related) + builder.should_receive("where") + relation = HasOne( + Builder(QueryBuilder(None, None, None)), parent, "foreign_key", "id" + ) + related.should_receive("get_table").and_return("table") + related.should_receive("get_updated_at_column").and_return("updated_at") now = pendulum.now() - related.should_receive('fresh_timestamp').and_return(now) - builder.should_receive('update').once().with_args({'updated_at': now}) + related.should_receive("fresh_timestamp").and_return(now) + builder.should_receive("update").once().with_args({"updated_at": now}) relation.touch() class OrmRelationResetModelStub(Model): - def get_query(self): return self.new_query().get_query() 
diff --git a/tests/orm/scopes/__init__.py b/tests/orm/scopes/__init__.py index 633f8661..40a96afc 100644 --- a/tests/orm/scopes/__init__.py +++ b/tests/orm/scopes/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/orm/scopes/test_soft_deleting.py b/tests/orm/scopes/test_soft_deleting.py index 12468524..220e9924 100644 --- a/tests/orm/scopes/test_soft_deleting.py +++ b/tests/orm/scopes/test_soft_deleting.py @@ -8,7 +8,6 @@ class SoftDeletingScopeTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() @@ -17,8 +16,12 @@ def test_apply_scope_to_a_builder(self): query = flexmock(QueryBuilder(None, None, None)) builder = Builder(query) model = flexmock(ModelStub()) - model.should_receive('get_qualified_deleted_at_column').once().and_return('table.deleted_at') - builder.get_query().should_receive('where_null').once().with_args('table.deleted_at') + model.should_receive("get_qualified_deleted_at_column").once().and_return( + "table.deleted_at" + ) + builder.get_query().should_receive("where_null").once().with_args( + "table.deleted_at" + ) scope.apply(builder, model) @@ -26,10 +29,10 @@ def test_force_delete_extension(self): scope = SoftDeletingScope() builder = Builder(None) scope.extend(builder) - callback = builder.get_macro('force_delete') + callback = builder.get_macro("force_delete") query = flexmock(QueryBuilder(None, None, None)) given_builder = Builder(query) - query.should_receive('delete').once() + query.should_receive("delete").once() callback(given_builder) @@ -37,13 +40,13 @@ def test_restore_extension(self): scope = SoftDeletingScope() builder = Builder(None) scope.extend(builder) - callback = builder.get_macro('restore') + callback = builder.get_macro("restore") query = flexmock(QueryBuilder(None, None, None)) builder_mock = flexmock(BuilderWithTrashedStub) given_builder = BuilderWithTrashedStub(query) - builder_mock.should_receive('with_trashed').once() - 
builder_mock.should_receive('get_model').once().and_return(ModelStub()) - builder_mock.should_receive('update').once().with_args({'deleted_at': None}) + builder_mock.should_receive("with_trashed").once() + builder_mock.should_receive("get_model").once().and_return(ModelStub()) + builder_mock.should_receive("update").once().with_args({"deleted_at": None}) callback(given_builder) @@ -51,7 +54,7 @@ def test_with_trashed_extension(self): scope = flexmock(SoftDeletingScope()) builder = Builder(None) scope.extend(builder) - callback = builder.get_macro('with_trashed') + callback = builder.get_macro("with_trashed") query = flexmock(QueryBuilder(None, None, None)) given_builder = Builder(query) model = flexmock(ModelStub()) @@ -65,15 +68,13 @@ def test_with_trashed_extension(self): class ModelStub(Model): - def get_qualified_deleted_at_column(self): - return 'table.deleted_at' + return "table.deleted_at" def get_deleted_at_column(self): - return 'deleted_at' + return "deleted_at" class BuilderWithTrashedStub(Builder): - def with_trashed(self): pass diff --git a/tests/orm/test_builder.py b/tests/orm/test_builder.py index b8975d7f..6e9c6556 100644 --- a/tests/orm/test_builder.py +++ b/tests/orm/test_builder.py @@ -15,7 +15,6 @@ class BuilderTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() @@ -23,30 +22,26 @@ def test_find_method(self): builder = Builder(self.get_mock_query_builder()) builder.set_model(self.get_mock_model()) builder.get_query().where = mock.MagicMock() - builder.first = mock.MagicMock(return_value='baz') + builder.first = mock.MagicMock(return_value="baz") - result = builder.find('bar', ['column']) + result = builder.find("bar", ["column"]) - builder.get_query().where.assert_called_once_with( - 'foo_table.foo', '=', 'bar' - ) - self.assertEqual('baz', result) + builder.get_query().where.assert_called_once_with("foo_table.foo", "=", "bar") + self.assertEqual("baz", result) def test_find_or_new_model_found(self): model = self.get_mock_model() 
- model.find_or_new = mock.MagicMock(return_value='baz') + model.find_or_new = mock.MagicMock(return_value="baz") builder = Builder(self.get_mock_query_builder()) builder.set_model(model) builder.get_query().where = mock.MagicMock() - builder.first = mock.MagicMock(return_value='baz') + builder.first = mock.MagicMock(return_value="baz") - expected = model.find_or_new('bar', ['column']) - result = builder.find('bar', ['column']) + expected = model.find_or_new("bar", ["column"]) + result = builder.find("bar", ["column"]) - builder.get_query().where.assert_called_once_with( - 'foo_table.foo', '=', 'bar' - ) + builder.get_query().where.assert_called_once_with("foo_table.foo", "=", "bar") self.assertEqual(expected, result) def test_find_or_new_model_not_found(self): @@ -58,12 +53,10 @@ def test_find_or_new_model_not_found(self): builder.get_query().where = mock.MagicMock() builder.first = mock.MagicMock(return_value=None) - result = model.find_or_new('bar', ['column']) - find_result = builder.find('bar', ['column']) + result = model.find_or_new("bar", ["column"]) + find_result = builder.find("bar", ["column"]) - builder.get_query().where.assert_called_once_with( - 'foo_table.foo', '=', 'bar' - ) + builder.get_query().where.assert_called_once_with("foo_table.foo", "=", "bar") self.assertIsNone(find_result) self.assertIsInstance(result, Model) @@ -75,20 +68,11 @@ def test_find_or_fail_raises_model_not_found_exception(self): builder.get_query().where = mock.MagicMock() builder.first = mock.MagicMock(return_value=None) - self.assertRaises( - ModelNotFound, - builder.find_or_fail, - 'bar', - ['column'] - ) + self.assertRaises(ModelNotFound, builder.find_or_fail, "bar", ["column"]) - builder.get_query().where.assert_called_once_with( - 'foo_table.foo', '=', 'bar' - ) + builder.get_query().where.assert_called_once_with("foo_table.foo", "=", "bar") - builder.first.assert_called_once_with( - ['column'] - ) + builder.first.assert_called_once_with(["column"]) def 
test_find_or_fail_with_many_raises_model_not_found_exception(self): model = self.get_mock_model() @@ -98,20 +82,11 @@ def test_find_or_fail_with_many_raises_model_not_found_exception(self): builder.get_query().where_in = mock.MagicMock() builder.get = mock.MagicMock(return_value=Collection([1])) - self.assertRaises( - ModelNotFound, - builder.find_or_fail, - [1, 2], - ['column'] - ) + self.assertRaises(ModelNotFound, builder.find_or_fail, [1, 2], ["column"]) - builder.get_query().where_in.assert_called_once_with( - 'foo_table.foo', [1, 2] - ) + builder.get_query().where_in.assert_called_once_with("foo_table.foo", [1, 2]) - builder.get.assert_called_once_with( - ['column'] - ) + builder.get.assert_called_once_with(["column"]) def test_first_or_fail_raises_model_not_found_exception(self): model = self.get_mock_model() @@ -120,15 +95,9 @@ def test_first_or_fail_raises_model_not_found_exception(self): builder.set_model(model) builder.first = mock.MagicMock(return_value=None) - self.assertRaises( - ModelNotFound, - builder.first_or_fail, - ['column'] - ) + self.assertRaises(ModelNotFound, builder.first_or_fail, ["column"]) - builder.first.assert_called_once_with( - ['column'] - ) + builder.first.assert_called_once_with(["column"]) def test_find_with_many(self): model = self.get_mock_model() @@ -136,18 +105,14 @@ def test_find_with_many(self): builder = Builder(self.get_mock_query_builder()) builder.set_model(model) builder.get_query().where_in = mock.MagicMock() - builder.get = mock.MagicMock(return_value='baz') + builder.get = mock.MagicMock(return_value="baz") - result = builder.find([1, 2], ['column']) - self.assertEqual('baz', result) + result = builder.find([1, 2], ["column"]) + self.assertEqual("baz", result) - builder.get_query().where_in.assert_called_once_with( - 'foo_table.foo', [1, 2] - ) + builder.get_query().where_in.assert_called_once_with("foo_table.foo", [1, 2]) - builder.get.assert_called_once_with( - ['column'] - ) + 
builder.get.assert_called_once_with(["column"]) def test_first(self): model = self.get_mock_model() @@ -155,41 +120,41 @@ def test_first(self): builder = Builder(self.get_mock_query_builder()) builder.set_model(model) builder.take = mock.MagicMock(return_value=builder) - builder.get = mock.MagicMock(return_value=Collection(['bar'])) + builder.get = mock.MagicMock(return_value=Collection(["bar"])) result = builder.first() - self.assertEqual('bar', result) + self.assertEqual("bar", result) - builder.take.assert_called_once_with( - 1 - ) + builder.take.assert_called_once_with(1) - builder.get.assert_called_once_with( - ['*'] - ) + builder.get.assert_called_once_with(["*"]) def test_get_loads_models_and_hydrates_eager_relations(self): flexmock(Builder) builder = Builder(self.get_mock_query_builder()) - builder.should_receive('get_models').with_args(['foo']).and_return(['bar']) - builder.should_receive('eager_load_relations').with_args(['bar']).and_return(['bar', 'baz']) + builder.should_receive("get_models").with_args(["foo"]).and_return(["bar"]) + builder.should_receive("eager_load_relations").with_args(["bar"]).and_return( + ["bar", "baz"] + ) builder.set_model(self.get_mock_model()) - builder.get_model().new_collection = mock.MagicMock(return_value=Collection(['bar', 'baz'])) + builder.get_model().new_collection = mock.MagicMock( + return_value=Collection(["bar", "baz"]) + ) - results = builder.get(['foo']) - self.assertEqual(['bar', 'baz'], results.all()) + results = builder.get(["foo"]) + self.assertEqual(["bar", "baz"], results.all()) - builder.get_model().new_collection.assert_called_with(['bar', 'baz']) + builder.get_model().new_collection.assert_called_with(["bar", "baz"]) def test_get_does_not_eager_relations_when_no_results_are_returned(self): flexmock(Builder) builder = Builder(self.get_mock_query_builder()) - builder.should_receive('get_models').with_args(['foo']).and_return(['bar']) - 
builder.should_receive('eager_load_relations').with_args(['bar']).and_return([]) + builder.should_receive("get_models").with_args(["foo"]).and_return(["bar"]) + builder.should_receive("eager_load_relations").with_args(["bar"]).and_return([]) builder.set_model(self.get_mock_model()) builder.get_model().new_collection = mock.MagicMock(return_value=Collection([])) - results = builder.get(['foo']) + results = builder.get(["foo"]) self.assertEqual([], results.all()) builder.get_model().new_collection.assert_called_with([]) @@ -197,27 +162,34 @@ def test_get_does_not_eager_relations_when_no_results_are_returned(self): def test_pluck_with_model_found(self): builder = Builder(self.get_mock_query_builder()) - model = {'name': 'foo'} + model = {"name": "foo"} builder.first = mock.MagicMock(return_value=model) - self.assertEqual('foo', builder.pluck('name')) + self.assertEqual("foo", builder.pluck("name")) - builder.first.assert_called_once_with( - ['name'] - ) + builder.first.assert_called_once_with(["name"]) def test_pluck_with_model_not_found(self): builder = Builder(self.get_mock_query_builder()) builder.first = mock.MagicMock(return_value=None) - self.assertIsNone(builder.pluck('name')) + self.assertIsNone(builder.pluck("name")) def test_chunk(self): - builder = Builder(self.get_mock_query_builder()) - results = [Collection(['foo1', 'foo2']), Collection(['foo3']), Collection([])] - builder.for_page = mock.MagicMock(return_value=builder) - builder.get = mock.MagicMock(side_effect=results) + query_builder = self.get_mock_query_builder() + query_results = [["foo1", "foo2"], ["foo3"]] + query_builder.chunk = mock.MagicMock(return_value=query_results) + + builder = Builder(query_builder) + model = self.get_mock_model() + builder.set_model(model) + + results = [Collection(["foo1", "foo2"]), Collection(["foo3"])] + + model.hydrate = mock.MagicMock(return_value=[]) + model.new_collection = mock.MagicMock(side_effect=results) + model.get_connection_name = 
mock.MagicMock(return_value="foo") i = 0 for result in builder.chunk(2): @@ -225,61 +197,59 @@ def test_chunk(self): i += 1 - builder.for_page.assert_has_calls([ - mock.call(1, 2), - mock.call(2, 2), - mock.call(3, 2) - ]) + self.assertEqual(i, 2) + + query_builder.chunk.assert_has_calls([mock.call(2)]) + model.hydrate.assert_has_calls( + [mock.call(["foo1", "foo2"], "foo"), mock.call(["foo3"], "foo")] + ) + model.new_collection.assert_has_calls([mock.call([]), mock.call([])]) # TODO: lists with get mutators def test_lists_without_model_getters(self): builder = self.get_builder() - builder.get_query().get = mock.MagicMock(return_value=[{'name': 'bar'}, {'name': 'baz'}]) + builder.get_query().get = mock.MagicMock( + return_value=[{"name": "bar"}, {"name": "baz"}] + ) builder.set_model(self.get_mock_model()) builder.get_model().has_get_mutator = mock.MagicMock(return_value=False) - result = builder.lists('name') - self.assertEqual(['bar', 'baz'], result) + result = builder.lists("name") + self.assertEqual(["bar", "baz"], result) - builder.get_query().get.assert_called_once_with(['name']) + builder.get_query().get.assert_called_once_with(["name"]) def test_get_models_hydrates_models(self): builder = Builder(self.get_mock_query_builder()) - records = Collection([{ - 'name': 'john', 'age': 26 - }, { - 'name': 'jane', 'age': 28 - }]) + records = Collection([{"name": "john", "age": 26}, {"name": "jane", "age": 28}]) builder.get_query().get = mock.MagicMock(return_value=records) model = self.get_mock_model() builder.set_model(model) - model.get_connection_name = mock.MagicMock(return_value='foo_connection') - model.hydrate = mock.MagicMock(return_value=Collection(['hydrated'])) - models = builder.get_models(['foo']) + model.get_connection_name = mock.MagicMock(return_value="foo_connection") + model.hydrate = mock.MagicMock(return_value=Collection(["hydrated"])) + models = builder.get_models(["foo"]) - self.assertEqual(models.all(), ['hydrated']) + 
self.assertEqual(models.all(), ["hydrated"]) model.get_table.assert_called_once_with() model.get_connection_name.assert_called_once_with() - model.hydrate.assert_called_once_with( - records, 'foo_connection' - ) + model.hydrate.assert_called_once_with(records, "foo_connection") def test_macros_are_called_on_builder(self): - builder = Builder(QueryBuilder( - flexmock(Connection), - flexmock(QueryGrammar), - flexmock(QueryProcessor) - )) + builder = Builder( + QueryBuilder( + flexmock(Connection), flexmock(QueryGrammar), flexmock(QueryProcessor) + ) + ) def foo_bar(builder): builder.foobar = True return builder - builder.macro('foo_bar', foo_bar) + builder.macro("foo_bar", foo_bar) result = builder.foo_bar() self.assertEqual(result, builder) @@ -290,29 +260,36 @@ def test_eager_load_relations_load_top_level_relationships(self): builder = Builder(flexmock(QueryBuilder(None, None, None))) nop1 = lambda: None nop2 = lambda: None - builder.set_eager_loads({'foo': nop1, 'foo.bar': nop2}) - builder.should_receive('_load_relation').with_args(['models'], 'foo', nop1).and_return(['foo']) + builder.set_eager_loads({"foo": nop1, "foo.bar": nop2}) + builder.should_receive("_load_relation").with_args( + ["models"], "foo", nop1 + ).and_return(["foo"]) - results = builder.eager_load_relations(['models']) - self.assertEqual(['foo'], results) + results = builder.eager_load_relations(["models"]) + self.assertEqual(["foo"], results) def test_eager_load_accept_queries(self): model = OrmBuilderTestModelCloseRelated() flexmock(Builder) builder = Builder(flexmock(QueryBuilder(None, None, None))) - nop1 = OrmBuilderTestModelFarRelatedStub.where('id', 5) - builder.set_eager_loads({'foo': nop1}) + nop1 = OrmBuilderTestModelFarRelatedStub.where("id", 5) + builder.set_eager_loads({"foo": nop1}) relation = flexmock() - relation.should_receive('add_eager_constraints').once().with_args(['models']) - relation.should_receive('init_relation').once().with_args(['models'], 'foo').and_return(['models']) 
- relation.should_receive('get_eager').once().and_return(['results']) - relation.should_receive('match').once()\ - .with_args(['models'], ['results'], 'foo').and_return(['foo']) - builder.should_receive('get_relation').once().with_args('foo').and_return(relation) - relation.should_receive('merge_query').with_args(nop1).and_return(relation) + relation.should_receive("add_eager_constraints").once().with_args(["models"]) + relation.should_receive("init_relation").once().with_args( + ["models"], "foo" + ).and_return(["models"]) + relation.should_receive("get_eager").once().and_return(["results"]) + relation.should_receive("match").once().with_args( + ["models"], ["results"], "foo" + ).and_return(["foo"]) + builder.should_receive("get_relation").once().with_args("foo").and_return( + relation + ) + relation.should_receive("merge_query").with_args(nop1).and_return(relation) - results = builder.eager_load_relations(['models']) - self.assertEqual(['foo'], results) + results = builder.eager_load_relations(["models"]) + self.assertEqual(["foo"], results) def test_relationship_eager_load_process(self): proof = flexmock() @@ -322,18 +299,23 @@ def test_relationship_eager_load_process(self): def callback(q): proof.foo = q - builder.set_eager_loads({'orders': callback}) + builder.set_eager_loads({"orders": callback}) relation = flexmock() - relation.should_receive('add_eager_constraints').once().with_args(['models']) - relation.should_receive('init_relation').once().with_args(['models'], 'orders').and_return(['models']) - relation.should_receive('get_eager').once().and_return(['results']) - relation.should_receive('get_query').once().and_return(relation) - relation.should_receive('match').once()\ - .with_args(['models'], ['results'], 'orders').and_return(['models.matched']) - builder.should_receive('get_relation').once().with_args('orders').and_return(relation) - results = builder.eager_load_relations(['models']) - - self.assertEqual(['models.matched'], results) + 
relation.should_receive("add_eager_constraints").once().with_args(["models"]) + relation.should_receive("init_relation").once().with_args( + ["models"], "orders" + ).and_return(["models"]) + relation.should_receive("get_eager").once().and_return(["results"]) + relation.should_receive("get_query").once().and_return(relation) + relation.should_receive("match").once().with_args( + ["models"], ["results"], "orders" + ).and_return(["models.matched"]) + builder.should_receive("get_relation").once().with_args("orders").and_return( + relation + ) + results = builder.eager_load_relations(["models"]) + + self.assertEqual(["models.matched"], results) self.assertEqual(relation, proof.foo) def test_get_relation_properly_sets_nested_relationships(self): @@ -341,32 +323,32 @@ def test_get_relation_properly_sets_nested_relationships(self): builder = Builder(flexmock(QueryBuilder(None, None, None))) model = flexmock(Model()) relation = flexmock() - model.set_relation('orders', relation) + model.set_relation("orders", relation) builder.set_model(model) relation_query = flexmock() - relation.should_receive('get_query').and_return(relation_query) - relation_query.should_receive('with_').once().with_args({'lines': None, 'lines.details': None}) - builder.set_eager_loads({ - 'orders': None, - 'orders.lines': None, - 'orders.lines.details': None - }) + relation.should_receive("get_query").and_return(relation_query) + relation_query.should_receive("with_").once().with_args( + {"lines": None, "lines.details": None} + ) + builder.set_eager_loads( + {"orders": None, "orders.lines": None, "orders.lines.details": None} + ) - relation = builder.get_relation('orders') + relation = builder.get_relation("orders") def test_query_passthru(self): builder = self.get_builder() - builder.get_query().foobar = mock.MagicMock(return_value='foo') + builder.get_query().foobar = mock.MagicMock(return_value="foo") self.assertIsInstance(builder.foobar(), Builder) self.assertEqual(builder.foobar(), builder) 
builder = self.get_builder() - builder.get_query().insert = mock.MagicMock(return_value='foo') + builder.get_query().insert = mock.MagicMock(return_value="foo") - self.assertEqual('foo', builder.insert(['bar'])) + self.assertEqual("foo", builder.insert(["bar"])) - builder.get_query().insert.assert_called_once_with(['bar']) + builder.get_query().insert.assert_called_once_with(["bar"]) def test_query_scopes(self): builder = self.get_builder() @@ -381,11 +363,11 @@ def test_query_scopes(self): def test_simple_where(self): builder = self.get_builder() builder.get_query().where = mock.MagicMock() - result = builder.where('foo', '=', 'bar') + result = builder.where("foo", "=", "bar") self.assertEqual(builder, result) - builder.get_query().where.assert_called_once_with('foo', '=', 'bar', 'and') + builder.get_query().where.assert_called_once_with("foo", "=", "bar", "and") def test_nested_where(self): nested_query = self.get_builder() @@ -399,42 +381,48 @@ def test_nested_where(self): result = builder.where(nested_query) self.assertEqual(builder, result) - builder.get_query().add_nested_where_query.assert_called_once_with(nested_raw_query, 'and') + builder.get_query().add_nested_where_query.assert_called_once_with( + nested_raw_query, "and" + ) # TODO: nested query with scopes def test_delete_override(self): builder = self.get_builder() - builder.on_delete(lambda builder_: {'foo': builder_}) + builder.on_delete(lambda builder_: {"foo": builder_}) - self.assertEqual({'foo': builder}, builder.delete()) + self.assertEqual({"foo": builder}, builder.delete()) def test_has_nested(self): - builder = OrmBuilderTestModelParentStub.where_has('foo', lambda q: q.has('bar')) + builder = OrmBuilderTestModelParentStub.where_has("foo", lambda q: q.has("bar")) - result = OrmBuilderTestModelParentStub.has('foo.bar').to_sql() + result = OrmBuilderTestModelParentStub.has("foo.bar").to_sql() self.assertEqual(builder.to_sql(), result) def test_has_nested_with_constraints(self): model = 
OrmBuilderTestModelParentStub - builder = model.where_has('foo', lambda q: q.where_has('bar', lambda q: q.where('baz', 'bam'))).to_sql() + builder = model.where_has( + "foo", lambda q: q.where_has("bar", lambda q: q.where("baz", "bam")) + ).to_sql() - result = model.where_has('foo.bar', lambda q: q.where('baz', 'bam')).to_sql() + result = model.where_has("foo.bar", lambda q: q.where("baz", "bam")).to_sql() self.assertEqual(builder, result) def test_where_exists_accepts_builder_instance(self): model = OrmBuilderTestModelCloseRelated - builder = model.where_exists(OrmBuilderTestModelFarRelatedStub.where('foo', 'bar')).to_sql() + builder = model.where_exists( + OrmBuilderTestModelFarRelatedStub.where("foo", "bar") + ).to_sql() self.assertEqual( 'SELECT * FROM "orm_builder_test_model_close_relateds" ' 'WHERE EXISTS (SELECT * FROM "orm_builder_test_model_far_related_stubs" WHERE "foo" = ?)', - builder + builder, ) def get_builder(self): @@ -449,17 +437,12 @@ def get_mock_query_builder(self): connection = MockConnection().prepare_mock() processor = MockProcessor().prepare_mock() - builder = MockQueryBuilder( - connection, - QueryGrammar(), - processor - ).prepare_mock() + builder = MockQueryBuilder(connection, QueryGrammar(), processor).prepare_mock() return builder class OratorTestModel(Model): - @classmethod def _boot_columns(cls): return [] @@ -475,21 +458,18 @@ class OrmBuilderTestModelFarRelatedStub(OratorTestModel): class OrmBuilderTestModelScopeStub(OratorTestModel): - @scope def approved(self, query): - query.where('foo', 'bar') + query.where("foo", "bar") class OrmBuilderTestModelCloseRelated(OratorTestModel): - @has_many def bar(self): return OrmBuilderTestModelFarRelatedStub class OrmBuilderTestModelParentStub(OratorTestModel): - @belongs_to def foo(self): return OrmBuilderTestModelCloseRelated diff --git a/tests/orm/test_factory.py b/tests/orm/test_factory.py index 1b751b78..6bb7f3b5 100644 --- a/tests/orm/test_factory.py +++ b/tests/orm/test_factory.py @@ 
-8,7 +8,6 @@ class FactoryTestCase(OratorTestCase): - @classmethod def setUpClass(cls): Model.set_connection_resolver(DatabaseConnectionResolver()) @@ -24,50 +23,43 @@ def schema(self): return self.connection().get_schema_builder() def setUp(self): - with self.schema().create('users') as table: - table.increments('id') - table.string('name').unique() - table.string('email').unique() - table.boolean('admin').default(True) + with self.schema().create("users") as table: + table.increments("id") + table.string("name").unique() + table.string("email").unique() + table.boolean("admin").default(True) table.timestamps() - with self.schema().create('posts') as table: - table.increments('id') - table.integer('user_id') - table.string('title').unique() - table.text('content').unique() + with self.schema().create("posts") as table: + table.increments("id") + table.integer("user_id") + table.string("title").unique() + table.text("content").unique() table.timestamps() - table.foreign('user_id').references('id').on('users') + table.foreign("user_id").references("id").on("users") self.factory = Factory() @self.factory.define(User) def users_factory(faker): - return { - 'name': faker.name(), - 'email': faker.email(), - 'admin': False - } + return {"name": faker.name(), "email": faker.email(), "admin": False} - @self.factory.define(User, 'admin') + @self.factory.define(User, "admin") def users_factory(faker): attributes = self.factory.raw(User) - attributes.update({'admin': True}) + attributes.update({"admin": True}) return attributes @self.factory.define(Post) def posts_factory(faker): - return { - 'title': faker.sentence(), - 'content': faker.text() - } + return {"title": faker.sentence(), "content": faker.text()} def tearDown(self): - self.schema().drop('posts') - self.schema().drop('users') + self.schema().drop("posts") + self.schema().drop("users") def test_factory_make(self): user = self.factory.make(User) @@ -75,7 +67,7 @@ def test_factory_make(self): 
self.assertIsInstance(user, User) self.assertIsNotNone(user.name) self.assertIsNotNone(user.email) - self.assertIsNone(User.where('name', user.name).first()) + self.assertIsNone(User.where("name", user.name).first()) def test_factory_create(self): user = self.factory.create(User) @@ -83,15 +75,15 @@ def test_factory_create(self): self.assertIsInstance(user, User) self.assertIsNotNone(user.name) self.assertIsNotNone(user.email) - self.assertIsNotNone(User.where('name', user.name).first()) + self.assertIsNotNone(User.where("name", user.name).first()) def test_factory_create_with_attributes(self): - user = self.factory.create(User, name='foo', email='foo@bar.com') + user = self.factory.create(User, name="foo", email="foo@bar.com") self.assertIsInstance(user, User) - self.assertEqual('foo', user.name) - self.assertEqual('foo@bar.com', user.email) - self.assertIsNotNone(User.where('name', user.name).first()) + self.assertEqual("foo", user.name) + self.assertEqual("foo@bar.com", user.email) + self.assertIsNotNone(User.where("name", user.name).first()) def test_factory_create_with_relations(self): users = self.factory.build(User, 3) @@ -111,19 +103,19 @@ def test_factory_call(self): self.assertEqual(3, len(users)) self.assertFalse(users[0].admin) - admin = self.factory(User, 'admin').create() + admin = self.factory(User, "admin").create() self.assertTrue(admin.admin) - admins = self.factory(User, 'admin', 3).create() + admins = self.factory(User, "admin", 3).create() self.assertEqual(3, len(admins)) self.assertTrue(admins[0].admin) class User(Model): - __guarded__ = ['id'] + __guarded__ = ["id"] - @has_many('user_id') + @has_many("user_id") def posts(self): return Post @@ -132,7 +124,7 @@ class Post(Model): __guarded__ = [] - @belongs_to('user_id') + @belongs_to("user_id") def user(self): return User @@ -145,12 +137,14 @@ def connection(self, name=None): if self._connection: return self._connection - self._connection = 
SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'})) + self._connection = SQLiteConnection( + SQLiteConnector().connect({"database": ":memory:"}) + ) return self._connection def get_default_connection(self): - return 'default' + return "default" def set_default_connection(self, name): pass diff --git a/tests/orm/test_model.py b/tests/orm/test_model.py index 4bf7480a..d81abaab 100644 --- a/tests/orm/test_model.py +++ b/tests/orm/test_model.py @@ -24,56 +24,54 @@ class OrmModelTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_attributes_manipulation(self): model = OrmModelStub() - model.name = 'foo' - self.assertEqual('foo', model.name) + model.name = "foo" + self.assertEqual("foo", model.name) del model.name - self.assertFalse(hasattr(model, 'name')) + self.assertFalse(hasattr(model, "name")) - model.list_items = {'name': 'john'} - self.assertEqual({'name': 'john'}, model.list_items) + model.list_items = {"name": "john"} + self.assertEqual({"name": "john"}, model.list_items) attributes = model.get_attributes() - self.assertEqual(json.dumps({'name': 'john'}), attributes['list_items']) + self.assertEqual(json.dumps({"name": "john"}), attributes["list_items"]) def test_dirty_attributes(self): - model = OrmModelStub(foo='1', bar=2, baz=3) + model = OrmModelStub(foo="1", bar=2, baz=3) model.foo = 1 model.bar = 20 model.baz = 30 self.assertTrue(model.is_dirty()) - self.assertTrue(model.is_dirty('foo')) - self.assertTrue(model.is_dirty('bar')) - self.assertTrue(model.is_dirty('baz')) - self.assertTrue(model.is_dirty('foo', 'bar', 'baz')) + self.assertTrue(model.is_dirty("foo")) + self.assertTrue(model.is_dirty("bar")) + self.assertTrue(model.is_dirty("baz")) + self.assertTrue(model.is_dirty("foo", "bar", "baz")) def test_calculated_attributes(self): model = OrmModelStub() - model.password = 'secret' + model.password = "secret" attributes = model.get_attributes() - self.assertFalse('password' in attributes) - 
self.assertEqual('******', model.password) - self.assertEqual('5ebe2294ecd0e0f08eab7690d2a6ee69', attributes['password_hash']) - self.assertEqual('5ebe2294ecd0e0f08eab7690d2a6ee69', model.password_hash) + self.assertFalse("password" in attributes) + self.assertEqual("******", model.password) + self.assertEqual( + "5ebe2294ecd0e0f08eab7690d2a6ee69", attributes["password_hash"] + ) + self.assertEqual("5ebe2294ecd0e0f08eab7690d2a6ee69", model.password_hash) def test_new_instance_returns_instance_wit_attributes_set(self): model = OrmModelStub() - instance = model.new_instance({'name': 'john'}) + instance = model.new_instance({"name": "john"}) self.assertIsInstance(instance, OrmModelStub) - self.assertEqual('john', instance.name) + self.assertEqual("john", instance.name) def test_hydrate_creates_collection_of_models(self): - data = [ - {'name': 'john'}, - {'name': 'jane'} - ] - collection = OrmModelStub.hydrate(data, 'foo_connection') + data = [{"name": "john"}, {"name": "jane"}] + collection = OrmModelStub.hydrate(data, "foo_connection") self.assertIsInstance(collection, Collection) self.assertEqual(2, len(collection)) @@ -81,10 +79,10 @@ def test_hydrate_creates_collection_of_models(self): self.assertIsInstance(collection[1], OrmModelStub) self.assertEqual(collection[0].get_attributes(), collection[0].get_original()) self.assertEqual(collection[1].get_attributes(), collection[1].get_original()) - self.assertEqual('john', collection[0].name) - self.assertEqual('jane', collection[1].name) - self.assertEqual('foo_connection', collection[0].get_connection_name()) - self.assertEqual('foo_connection', collection[1].get_connection_name()) + self.assertEqual("john", collection[0].name) + self.assertEqual("jane", collection[1].name) + self.assertEqual("foo_connection", collection[0].get_connection_name()) + self.assertEqual("foo_connection", collection[1].get_connection_name()) def test_hydrate_raw_makes_raw_query(self): model = OrmModelHydrateRawStub() @@ -97,22 +95,22 @@ def 
_set_connection(name): return model - OrmModelHydrateRawStub.set_connection = mock.MagicMock(side_effect=_set_connection) - collection = OrmModelHydrateRawStub.hydrate_raw('SELECT ?', ['foo']) - self.assertEqual('hydrated', collection) - connection.select.assert_called_once_with( - 'SELECT ?', ['foo'] + OrmModelHydrateRawStub.set_connection = mock.MagicMock( + side_effect=_set_connection ) + collection = OrmModelHydrateRawStub.hydrate_raw("SELECT ?", ["foo"]) + self.assertEqual("hydrated", collection) + connection.select.assert_called_once_with("SELECT ?", ["foo"]) def test_create_saves_new_model(self): - model = OrmModelSaveStub.create(name='john') + model = OrmModelSaveStub.create(name="john") self.assertTrue(model.get_saved()) - self.assertEqual('john', model.name) + self.assertEqual("john", model.name) def test_find_method_calls_query_builder_correctly(self): result = OrmModelFindStub.find(1) - self.assertEqual('foo', result) + self.assertEqual("foo", result) def test_find_use_write_connection(self): OrmModelFindWithWriteConnectionStub.on_write_connection().find(1) @@ -120,42 +118,44 @@ def test_find_use_write_connection(self): def test_find_with_list_calls_query_builder_correctly(self): result = OrmModelFindManyStub.find([1, 2]) - self.assertEqual('foo', result) + self.assertEqual("foo", result) def test_destroy_method_calls_query_builder_correctly(self): OrmModelDestroyStub.destroy(1, 2, 3) def test_with_calls_query_builder_correctly(self): - result = OrmModelWithStub.with_('foo', 'bar') - self.assertEqual('foo', result) + result = OrmModelWithStub.with_("foo", "bar") + self.assertEqual("foo", result) def test_update_process(self): query = flexmock(Builder) - query.should_receive('where').once().with_args('id', 1) - query.should_receive('update').once().with_args({'name': 'john'}) + query.should_receive("where").once().with_args("id", 1) + query.should_receive("update").once().with_args({"name": "john"}) model = OrmModelStub() - model.new_query = 
mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model.new_query = mock.MagicMock( + return_value=Builder(QueryBuilder(None, None, None)) + ) model._update_timestamps = mock.MagicMock() events = flexmock(Event()) model.__dispatcher__ = events - events.should_receive('fire').once()\ - .with_args('saving: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('updating: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('updated: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('saved: %s' % model.__class__.__name__, model)\ - .and_return(True) + events.should_receive("fire").once().with_args( + "saving: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "updating: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "updated: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "saved: %s" % model.__class__.__name__, model + ).and_return(True) model.id = 1 - model.foo = 'bar' + model.foo = "bar" model.sync_original() - model.name = 'john' + model.name = "john" model.set_exists(True) self.assertTrue(model.save()) @@ -164,32 +164,36 @@ def test_update_process(self): def test_update_process_does_not_override_timestamps(self): query = flexmock(Builder) - query.should_receive('where').once().with_args('id', 1) - query.should_receive('update').once().with_args({'created_at': 'foo', 'updated_at': 'bar'}) + query.should_receive("where").once().with_args("id", 1) + query.should_receive("update").once().with_args( + {"created_at": "foo", "updated_at": "bar"} + ) model = OrmModelStub() - model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model.new_query = 
mock.MagicMock( + return_value=Builder(QueryBuilder(None, None, None)) + ) model._update_timestamps = mock.MagicMock() events = flexmock(Event()) model.__dispatcher__ = events - events.should_receive('fire').once()\ - .with_args('saving: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('updating: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('updated: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('saved: %s' % model.__class__.__name__, model)\ - .and_return(True) + events.should_receive("fire").once().with_args( + "saving: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "updating: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "updated: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "saved: %s" % model.__class__.__name__, model + ).and_return(True) model.id = 1 model.sync_original() - model.created_at = 'foo' - model.updated_at = 'bar' + model.created_at = "foo" + model.updated_at = "bar" model.set_exists(True) self.assertTrue(model.save()) @@ -198,90 +202,104 @@ def test_update_process_does_not_override_timestamps(self): def test_creating_with_only_created_at_column(self): query_builder = flexmock(QueryBuilder) - query_builder.should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) + query_builder.should_receive("insert_get_id").once().with_args( + {"name": "john"}, "id" + ).and_return(1) model = flexmock(OrmModelCreatedAt()) - model.should_receive('new_query').and_return(Builder(QueryBuilder(None, None, None))) - model.should_receive('set_created_at').once() - model.should_receive('set_updated_at').never() - model.name = 'john' + 
model.should_receive("new_query").and_return( + Builder(QueryBuilder(None, None, None)) + ) + model.should_receive("set_created_at").once() + model.should_receive("set_updated_at").never() + model.name = "john" model.save() def test_creating_with_only_updated_at_column(self): query_builder = flexmock(QueryBuilder) - query_builder.should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) + query_builder.should_receive("insert_get_id").once().with_args( + {"name": "john"}, "id" + ).and_return(1) model = flexmock(OrmModelUpdatedAt()) - model.should_receive('new_query').and_return(Builder(QueryBuilder(None, None, None))) - model.should_receive('set_created_at').never() - model.should_receive('set_updated_at').once() - model.name = 'john' + model.should_receive("new_query").and_return( + Builder(QueryBuilder(None, None, None)) + ) + model.should_receive("set_created_at").never() + model.should_receive("set_updated_at").once() + model.name = "john" model.save() def test_updating_with_only_created_at_column(self): query = flexmock(Builder) - query.should_receive('where').once().with_args('id', 1) - query.should_receive('update').once().with_args({'name': 'john'}) + query.should_receive("where").once().with_args("id", 1) + query.should_receive("update").once().with_args({"name": "john"}) model = flexmock(OrmModelCreatedAt()) model.id = 1 model.sync_original() model.set_exists(True) - model.should_receive('new_query').and_return(Builder(QueryBuilder(None, None, None))) - model.should_receive('set_created_at').never() - model.should_receive('set_updated_at').never() - model.name = 'john' + model.should_receive("new_query").and_return( + Builder(QueryBuilder(None, None, None)) + ) + model.should_receive("set_created_at").never() + model.should_receive("set_updated_at").never() + model.name = "john" model.save() def test_updating_with_only_updated_at_column(self): query = flexmock(Builder) - query.should_receive('where').once().with_args('id', 1) - 
query.should_receive('update').once().with_args({'name': 'john'}) + query.should_receive("where").once().with_args("id", 1) + query.should_receive("update").once().with_args({"name": "john"}) model = flexmock(OrmModelUpdatedAt()) model.id = 1 model.sync_original() model.set_exists(True) - model.should_receive('new_query').and_return(Builder(QueryBuilder(None, None, None))) - model.should_receive('set_created_at').never() - model.should_receive('set_updated_at').once() - model.name = 'john' + model.should_receive("new_query").and_return( + Builder(QueryBuilder(None, None, None)) + ) + model.should_receive("set_created_at").never() + model.should_receive("set_updated_at").once() + model.name = "john" model.save() def test_update_is_cancelled_if_updating_event_returns_false(self): model = flexmock(OrmModelStub()) query = flexmock(Builder(flexmock(QueryBuilder(None, None, None)))) - model.should_receive('new_query_without_scopes').once().and_return(query) + model.should_receive("new_query_without_scopes").once().and_return(query) events = flexmock(Event()) model.__dispatcher__ = events - events.should_receive('fire').once()\ - .with_args('saving: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('updating: %s' % model.__class__.__name__, model)\ - .and_return(False) + events.should_receive("fire").once().with_args( + "saving: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "updating: %s" % model.__class__.__name__, model + ).and_return(False) model.set_exists(True) - model.foo = 'bar' + model.foo = "bar" self.assertFalse(model.save()) def test_update_process_without_timestamps(self): query = flexmock(Builder) - query.should_receive('where').once().with_args('id', 1) - query.should_receive('update').once().with_args({'name': 'john'}) + query.should_receive("where").once().with_args("id", 1) + 
query.should_receive("update").once().with_args({"name": "john"}) model = flexmock(OrmModelStub()) model.__timestamps__ = False - model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model.new_query = mock.MagicMock( + return_value=Builder(QueryBuilder(None, None, None)) + ) model._update_timestamps = mock.MagicMock() events = flexmock(Event()) model.__dispatcher__ = events - model.should_receive('_fire_model_event').and_return(True) + model.should_receive("_fire_model_event").and_return(True) model.id = 1 model.sync_original() - model.name = 'john' + model.name = "john" model.set_exists(True) self.assertTrue(model.save()) @@ -290,32 +308,34 @@ def test_update_process_without_timestamps(self): def test_update_process_uses_old_primary_key(self): query = flexmock(Builder) - query.should_receive('where').once().with_args('id', 1) - query.should_receive('update').once().with_args({'id': 2, 'name': 'john'}) + query.should_receive("where").once().with_args("id", 1) + query.should_receive("update").once().with_args({"id": 2, "name": "john"}) model = OrmModelStub() - model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model.new_query = mock.MagicMock( + return_value=Builder(QueryBuilder(None, None, None)) + ) model._update_timestamps = mock.MagicMock() events = flexmock(Event()) model.__dispatcher__ = events - events.should_receive('fire').once()\ - .with_args('saving: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('updating: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('updated: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('saved: %s' % model.__class__.__name__, model)\ - .and_return(True) + events.should_receive("fire").once().with_args( + "saving: %s" % model.__class__.__name__, model + 
).and_return(True) + events.should_receive("fire").once().with_args( + "updating: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "updated: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "saved: %s" % model.__class__.__name__, model + ).and_return(True) model.id = 1 model.sync_original() model.id = 2 - model.name = 'john' + model.name = "john" model.set_exists(True) self.assertTrue(model.save()) @@ -324,20 +344,18 @@ def test_update_process_uses_old_primary_key(self): def test_timestamps_are_returned_as_objects(self): model = Model() - model.set_raw_attributes({ - 'created_at': '2015-03-24', - 'updated_at': '2015-03-24' - }) + model.set_raw_attributes( + {"created_at": "2015-03-24", "updated_at": "2015-03-24"} + ) self.assertIsInstance(model.created_at, Pendulum) self.assertIsInstance(model.updated_at, Pendulum) def test_timestamps_are_returned_as_objects_from_timestamps_and_datetime(self): model = Model() - model.set_raw_attributes({ - 'created_at': datetime.datetime.utcnow(), - 'updated_at': time.time() - }) + model.set_raw_attributes( + {"created_at": datetime.datetime.utcnow(), "updated_at": time.time()} + ) self.assertIsInstance(model.created_at, Pendulum) self.assertIsInstance(model.updated_at, Pendulum) @@ -347,8 +365,8 @@ def test_timestamps_are_returned_as_objects_on_create(self): model.unguard() timestamps = { - 'created_at': datetime.datetime.now(), - 'updated_at': datetime.datetime.now() + "created_at": datetime.datetime.now(), + "updated_at": datetime.datetime.now(), } instance = model.new_instance(timestamps) @@ -363,8 +381,8 @@ def test_timestamps_return_none_if_set_to_none(self): model.unguard() timestamps = { - 'created_at': datetime.datetime.now(), - 'updated_at': datetime.datetime.now() + "created_at": datetime.datetime.now(), + "updated_at": datetime.datetime.now(), } instance = model.new_instance(timestamps) @@ -379,26 
+397,30 @@ def test_insert_process(self): model = OrmModelStub() query_builder = flexmock(QueryBuilder) - query_builder.should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) - model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + query_builder.should_receive("insert_get_id").once().with_args( + {"name": "john"}, "id" + ).and_return(1) + model.new_query = mock.MagicMock( + return_value=Builder(QueryBuilder(None, None, None)) + ) model._update_timestamps = mock.MagicMock() events = flexmock(Event()) model.__dispatcher__ = events - events.should_receive('fire').once()\ - .with_args('saving: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('creating: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('created: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('saved: %s' % model.__class__.__name__, model)\ - .and_return(True) - - model.name = 'john' + events.should_receive("fire").once().with_args( + "saving: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "creating: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "created: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "saved: %s" % model.__class__.__name__, model + ).and_return(True) + + model.name = "john" model.set_exists(False) self.assertTrue(model.save()) self.assertEqual(1, model.id) @@ -406,45 +428,47 @@ def test_insert_process(self): self.assertTrue(model._update_timestamps.called) model = OrmModelStub() - query_builder.should_receive('insert').once().with_args({'name': 'john'}) - model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + 
query_builder.should_receive("insert").once().with_args({"name": "john"}) + model.new_query = mock.MagicMock( + return_value=Builder(QueryBuilder(None, None, None)) + ) model._update_timestamps = mock.MagicMock() model.set_incrementing(False) events = flexmock(Event()) model.__dispatcher__ = events - events.should_receive('fire').once()\ - .with_args('saving: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('creating: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('created: %s' % model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('saved: %s' % model.__class__.__name__, model)\ - .and_return(True) - - model.name = 'john' + events.should_receive("fire").once().with_args( + "saving: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "creating: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "created: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "saved: %s" % model.__class__.__name__, model + ).and_return(True) + + model.name = "john" model.set_exists(False) self.assertTrue(model.save()) - self.assertFalse(hasattr(model, 'id')) + self.assertFalse(hasattr(model, "id")) self.assertTrue(model.exists) self.assertTrue(model._update_timestamps.called) def test_insert_is_cancelled_if_creating_event_returns_false(self): model = flexmock(OrmModelStub()) query = flexmock(Builder(flexmock(QueryBuilder(None, None, None)))) - model.should_receive('new_query_without_scopes').once().and_return(query) + model.should_receive("new_query_without_scopes").once().and_return(query) events = flexmock(Event()) model.__dispatcher__ = events - events.should_receive('fire').once()\ - .with_args('saving: %s' % 
model.__class__.__name__, model)\ - .and_return(True) - events.should_receive('fire').once()\ - .with_args('creating: %s' % model.__class__.__name__, model)\ - .and_return(False) + events.should_receive("fire").once().with_args( + "saving: %s" % model.__class__.__name__, model + ).and_return(True) + events.should_receive("fire").once().with_args( + "creating: %s" % model.__class__.__name__, model + ).and_return(False) self.assertFalse(model.save()) self.assertFalse(model.exists) @@ -452,8 +476,8 @@ def test_insert_is_cancelled_if_creating_event_returns_false(self): def test_delete_properly_deletes_model(self): model = OrmModelStub() builder = flexmock(Builder(QueryBuilder(None, None, None))) - builder.should_receive('where').once().with_args('id', 1).and_return(builder) - builder.should_receive('delete').once() + builder.should_receive("where").once().with_args("id", 1).and_return(builder) + builder.should_receive("delete").once() model.new_query = mock.MagicMock(return_value=builder) model.touch_owners = mock.MagicMock() @@ -465,13 +489,19 @@ def test_delete_properly_deletes_model(self): def test_push_no_relations(self): model = flexmock(Model()) - query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) builder = Builder(query) - builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) - model.should_receive('new_query').once().and_return(builder) - model.should_receive('_update_timestamps').once() + builder.get_query().should_receive("insert_get_id").once().with_args( + {"name": "john"}, "id" + ).and_return(1) + model.should_receive("new_query").once().and_return(builder) + model.should_receive("_update_timestamps").once() - model.name = 'john' + model.name = "john" model.set_exists(False) self.assertTrue(model.push()) @@ -480,15 +510,21 @@ def 
test_push_no_relations(self): def test_push_empty_one_relation(self): model = flexmock(Model()) - query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) builder = Builder(query) - builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) - model.should_receive('new_query').once().and_return(builder) - model.should_receive('_update_timestamps').once() + builder.get_query().should_receive("insert_get_id").once().with_args( + {"name": "john"}, "id" + ).and_return(1) + model.should_receive("new_query").once().and_return(builder) + model.should_receive("_update_timestamps").once() - model.name = 'john' + model.name = "john" model.set_exists(False) - model.set_relation('relation_one', None) + model.set_relation("relation_one", None) self.assertTrue(model.push()) self.assertEqual(1, model.id) @@ -497,26 +533,40 @@ def test_push_empty_one_relation(self): def test_push_one_relation(self): related1 = flexmock(Model()) - query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) builder = Builder(query) - builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'related1'}, 'id').and_return(2) - related1.should_receive('new_query').once().and_return(builder) - related1.should_receive('_update_timestamps').once() + builder.get_query().should_receive("insert_get_id").once().with_args( + {"name": "related1"}, "id" + ).and_return(2) + related1.should_receive("new_query").once().and_return(builder) + related1.should_receive("_update_timestamps").once() - related1.name = 'related1' + related1.name = "related1" related1.set_exists(False) model = flexmock(Model()) - 
model.should_receive('resolve_connection').and_return(MockConnection().prepare_mock()) - query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + model.should_receive("resolve_connection").and_return( + MockConnection().prepare_mock() + ) + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) builder = Builder(query) - builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) - model.should_receive('new_query').once().and_return(builder) - model.should_receive('_update_timestamps').once() + builder.get_query().should_receive("insert_get_id").once().with_args( + {"name": "john"}, "id" + ).and_return(1) + model.should_receive("new_query").once().and_return(builder) + model.should_receive("_update_timestamps").once() - model.name = 'john' + model.name = "john" model.set_exists(False) - model.set_relation('relation_one', related1) + model.set_relation("relation_one", related1) self.assertTrue(model.push()) self.assertEqual(1, model.id) @@ -528,15 +578,21 @@ def test_push_one_relation(self): def test_push_empty_many_relation(self): model = flexmock(Model()) - query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) builder = Builder(query) - builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) - model.should_receive('new_query').once().and_return(builder) - model.should_receive('_update_timestamps').once() + builder.get_query().should_receive("insert_get_id").once().with_args( + {"name": "john"}, "id" + ).and_return(1) + model.should_receive("new_query").once().and_return(builder) + model.should_receive("_update_timestamps").once() - model.name = 'john' + model.name = "john" model.set_exists(False) - 
model.set_relation('relation_many', Collection([])) + model.set_relation("relation_many", Collection([])) self.assertTrue(model.push()) self.assertEqual(1, model.id) @@ -545,51 +601,71 @@ def test_push_empty_many_relation(self): def test_push_many_relation(self): related1 = flexmock(Model()) - query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) builder = Builder(query) - builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'related1'}, 'id').and_return(2) - related1.should_receive('new_query').once().and_return(builder) - related1.should_receive('_update_timestamps').once() + builder.get_query().should_receive("insert_get_id").once().with_args( + {"name": "related1"}, "id" + ).and_return(2) + related1.should_receive("new_query").once().and_return(builder) + related1.should_receive("_update_timestamps").once() - related1.name = 'related1' + related1.name = "related1" related1.set_exists(False) related2 = flexmock(Model()) - query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) builder = Builder(query) - builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'related2'}, 'id').and_return(3) - related2.should_receive('new_query').once().and_return(builder) - related2.should_receive('_update_timestamps').once() + builder.get_query().should_receive("insert_get_id").once().with_args( + {"name": "related2"}, "id" + ).and_return(3) + related2.should_receive("new_query").once().and_return(builder) + related2.should_receive("_update_timestamps").once() - related2.name = 'related2' + related2.name = "related2" related2.set_exists(False) model = flexmock(Model()) - 
model.should_receive('resolve_connection').and_return(MockConnection().prepare_mock()) - query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor())) + model.should_receive("resolve_connection").and_return( + MockConnection().prepare_mock() + ) + query = flexmock( + QueryBuilder( + MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor() + ) + ) builder = Builder(query) - builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1) - model.should_receive('new_query').once().and_return(builder) - model.should_receive('_update_timestamps').once() + builder.get_query().should_receive("insert_get_id").once().with_args( + {"name": "john"}, "id" + ).and_return(1) + model.should_receive("new_query").once().and_return(builder) + model.should_receive("_update_timestamps").once() - model.name = 'john' + model.name = "john" model.set_exists(False) - model.set_relation('relation_many', Collection([related1, related2])) + model.set_relation("relation_many", Collection([related1, related2])) self.assertTrue(model.push()) self.assertEqual(1, model.id) self.assertTrue(model.exists) self.assertEqual(2, len(model.relation_many)) - self.assertEqual([2, 3], model.relation_many.lists('id')) + self.assertEqual([2, 3], model.relation_many.lists("id")) def test_new_query_returns_orator_query_builder(self): conn = flexmock(Connection) grammar = flexmock(QueryGrammar) processor = flexmock(QueryProcessor) - conn.should_receive('get_query_grammar').and_return(grammar) - conn.should_receive('get_post_processor').and_return(processor) + conn.should_receive("get_query_grammar").and_return(grammar) + conn.should_receive("get_post_processor").and_return(processor) resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').and_return(Connection(None)) + resolver.should_receive("connection").and_return(Connection(None)) OrmModelStub.set_connection_resolver(DatabaseManager({})) model = 
OrmModelStub() @@ -598,175 +674,174 @@ def test_new_query_returns_orator_query_builder(self): def test_get_and_set_table(self): model = OrmModelStub() - self.assertEqual('stub', model.get_table()) - model.set_table('foo') - self.assertEqual('foo', model.get_table()) + self.assertEqual("stub", model.get_table()) + model.set_table("foo") + self.assertEqual("foo", model.get_table()) def test_get_key_returns_primary_key_value(self): model = OrmModelStub() model.id = 1 self.assertEqual(1, model.get_key()) - self.assertEqual('id', model.get_key_name()) + self.assertEqual("id", model.get_key_name()) def test_connection_management(self): resolver = flexmock(DatabaseManager) - resolver.should_receive('connection').once().with_args('foo').and_return('bar') + resolver.should_receive("connection").once().with_args("foo").and_return("bar") OrmModelStub.set_connection_resolver(DatabaseManager({})) model = OrmModelStub() - model.set_connection('foo') + model.set_connection("foo") - self.assertEqual('bar', model.get_connection()) + self.assertEqual("bar", model.get_connection()) def test_serialize(self): model = OrmModelStub() - model.name = 'foo' + model.name = "foo" model.age = None - model.password = 'password1' - model.set_hidden(['password']) - model.set_relation('names', Collection([OrmModelStub(bar='baz'), OrmModelStub(bam='boom')])) - model.set_relation('partner', OrmModelStub(name='jane')) - model.set_relation('group', None) - model.set_relation('multi', Collection()) + model.password = "password1" + model.set_hidden(["password"]) + model.set_relation( + "names", Collection([OrmModelStub(bar="baz"), OrmModelStub(bam="boom")]) + ) + model.set_relation("partner", OrmModelStub(name="jane")) + model.set_relation("group", None) + model.set_relation("multi", Collection()) d = model.serialize() self.assertIsInstance(d, dict) - self.assertEqual('foo', d['name']) - self.assertEqual('baz', d['names'][0]['bar']) - self.assertEqual('boom', d['names'][1]['bam']) - 
self.assertEqual('jane', d['partner']['name']) - self.assertIsNone(d['group']) - self.assertEqual([], d['multi']) - self.assertIsNone(d['age']) - self.assertNotIn('password', d) - - model.set_appends(['appendable']) + self.assertEqual("foo", d["name"]) + self.assertEqual("baz", d["names"][0]["bar"]) + self.assertEqual("boom", d["names"][1]["bam"]) + self.assertEqual("jane", d["partner"]["name"]) + self.assertIsNone(d["group"]) + self.assertEqual([], d["multi"]) + self.assertIsNone(d["age"]) + self.assertNotIn("password", d) + + model.set_appends(["appendable"]) d = model.to_dict() - self.assertEqual('appended', d['appendable']) + self.assertEqual("appended", d["appendable"]) def test_to_dict_includes_default_formatted_timestamps(self): model = Model() - model.set_raw_attributes({ - 'created_at': '2015-03-24', - 'updated_at': '2015-03-25' - }) + model.set_raw_attributes( + {"created_at": "2015-03-24", "updated_at": "2015-03-25"} + ) d = model.to_dict() - self.assertEqual('2015-03-24T00:00:00+00:00', d['created_at']) - self.assertEqual('2015-03-25T00:00:00+00:00', d['updated_at']) + self.assertEqual("2015-03-24T00:00:00+00:00", d["created_at"]) + self.assertEqual("2015-03-25T00:00:00+00:00", d["updated_at"]) def test_to_dict_includes_custom_formatted_timestamps(self): class Stub(Model): - def get_date_format(self): - return '%d-%m-%-y' + return "%d-%m-%-y" - flexmock(Stub).should_receive('_boot_columns').and_return(['created_at', 'updated_at']) + flexmock(Stub).should_receive("_boot_columns").and_return( + ["created_at", "updated_at"] + ) model = Stub() - model.set_raw_attributes({ - 'created_at': '2015-03-24', - 'updated_at': '2015-03-25' - }) + model.set_raw_attributes( + {"created_at": "2015-03-24", "updated_at": "2015-03-25"} + ) d = model.to_dict() - self.assertEqual('24-03-15', d['created_at']) - self.assertEqual('25-03-15', d['updated_at']) + self.assertEqual("24-03-15", d["created_at"]) + self.assertEqual("25-03-15", d["updated_at"]) def 
test_visible_creates_dict_whitelist(self): model = OrmModelStub() - model.set_visible(['name']) - model.name = 'John' + model.set_visible(["name"]) + model.name = "John" model.age = 28 d = model.to_dict() - self.assertEqual({'name': 'John'}, d) + self.assertEqual({"name": "John"}, d) def test_hidden_can_also_exclude_relationships(self): model = OrmModelStub() - model.name = 'John' - model.set_relation('foo', ['bar']) - model.set_hidden(['foo', 'list_items', 'password']) + model.name = "John" + model.set_relation("foo", ["bar"]) + model.set_hidden(["foo", "list_items", "password"]) d = model.to_dict() - self.assertEqual({'name': 'John'}, d) + self.assertEqual({"name": "John"}, d) def test_to_dict_uses_mutators(self): model = OrmModelStub() model.list_items = [1, 2, 3] d = model.to_dict() - self.assertEqual([1, 2, 3], d['list_items']) + self.assertEqual([1, 2, 3], d["list_items"]) model = OrmModelStub(list_items=[1, 2, 3]) d = model.to_dict() - self.assertEqual([1, 2, 3], d['list_items']) + self.assertEqual([1, 2, 3], d["list_items"]) def test_hidden_are_ignored_when_visible(self): - model = OrmModelStub(name='john', age=28, id='foo') - model.set_visible(['name', 'id']) - model.set_hidden(['name', 'age']) + model = OrmModelStub(name="john", age=28, id="foo") + model.set_visible(["name", "id"]) + model.set_hidden(["name", "age"]) d = model.to_dict() - self.assertIn('name', d) - self.assertIn('id', d) - self.assertNotIn('age', d) + self.assertIn("name", d) + self.assertIn("id", d) + self.assertNotIn("age", d) def test_fillable(self): model = OrmModelStub() - model.fillable(['name', 'age']) - model.fill(name='foo', age=28) - self.assertEqual('foo', model.name) + model.fillable(["name", "age"]) + model.fill(name="foo", age=28) + self.assertEqual("foo", model.name) self.assertEqual(28, model.age) def test_fill_with_dict(self): model = OrmModelStub() - model.fill({'name': 'foo', 'age': 28}) - self.assertEqual('foo', model.name) + model.fill({"name": "foo", "age": 28}) + 
self.assertEqual("foo", model.name) self.assertEqual(28, model.age) def test_unguard_allows_anything(self): model = OrmModelStub() model.unguard() - model.guard(['*']) - model.fill(name='foo', age=28) - self.assertEqual('foo', model.name) + model.guard(["*"]) + model.fill(name="foo", age=28) + self.assertEqual("foo", model.name) self.assertEqual(28, model.age) model.reguard() def test_underscore_properties_are_not_filled(self): model = OrmModelStub() - model.fill(_foo='bar') + model.fill(_foo="bar") self.assertEqual({}, model.get_attributes()) def test_guarded(self): model = OrmModelStub() - model.guard(['name', 'age']) - model.fill(name='foo', age='bar', foo='bar') - self.assertFalse(hasattr(model, 'name')) - self.assertFalse(hasattr(model, 'age')) - self.assertEqual('bar', model.foo) + model.guard(["name", "age"]) + model.fill(name="foo", age="bar", foo="bar") + self.assertFalse(hasattr(model, "name")) + self.assertFalse(hasattr(model, "age")) + self.assertEqual("bar", model.foo) def test_fillable_overrides_guarded(self): model = OrmModelStub() - model.guard(['name', 'age']) - model.fillable(['age', 'foo']) - model.fill(name='foo', age='bar', foo='bar') - self.assertFalse(hasattr(model, 'name')) - self.assertEqual('bar', model.age) - self.assertEqual('bar', model.foo) + model.guard(["name", "age"]) + model.fillable(["age", "foo"]) + model.fill(name="foo", age="bar", foo="bar") + self.assertFalse(hasattr(model, "name")) + self.assertEqual("bar", model.age) + self.assertEqual("bar", model.foo) def test_global_guarded(self): model = OrmModelStub() - model.guard(['*']) + model.guard(["*"]) self.assertRaises( - MassAssignmentError, - model.fill, - name='foo', age='bar', foo='bar' + MassAssignmentError, model.fill, name="foo", age="bar", foo="bar" ) # TODO: test relations @@ -774,18 +849,16 @@ def test_global_guarded(self): def test_models_assumes_their_name(self): model = OrmModelNoTableStub() - self.assertEqual('orm_model_no_table_stubs', model.get_table()) + 
self.assertEqual("orm_model_no_table_stubs", model.get_table()) def test_mutator_cache_is_populated(self): model = OrmModelStub() - expected_attributes = sorted([ - 'list_items', - 'password', - 'appendable' - ]) + expected_attributes = sorted(["list_items", "password", "appendable"]) - self.assertEqual(expected_attributes, sorted(list(model._get_mutated_attributes().keys()))) + self.assertEqual( + expected_attributes, sorted(list(model._get_mutated_attributes().keys())) + ) def test_fresh_method(self): model = flexmock(OrmModelStub()) @@ -794,14 +867,14 @@ def test_fresh_method(self): flexmock(Builder) q = flexmock(QueryBuilder(None, None, None)) query = flexmock(Builder(q)) - query.should_receive('where').and_return(query) - query.get_query().should_receive('take').and_return(query) - query.should_receive('get').and_return(Collection([])) - model.should_receive('with_').once().with_args('foo', 'bar').and_return(query) + query.should_receive("where").and_return(query) + query.get_query().should_receive("take").and_return(query) + query.should_receive("get").and_return(Collection([])) + model.should_receive("with_").once().with_args("foo", "bar").and_return(query) - model.fresh(['foo', 'bar']) + model.fresh(["foo", "bar"]) - model.should_receive('with_').once().with_args().and_return(query) + model.should_receive("with_").once().with_args().and_return(query) model.fresh() @@ -809,33 +882,33 @@ def test_clone_model_makes_a_fresh_copy(self): model = OrmModelStub() model.id = 1 model.set_exists(True) - model.first = 'john' - model.last = 'doe' + model.first = "john" + model.last = "doe" model.created_at = model.fresh_timestamp() model.updated_at = model.fresh_timestamp() # TODO: relation clone = model.replicate() - self.assertFalse(hasattr(clone, 'id')) + self.assertFalse(hasattr(clone, "id")) self.assertFalse(clone.exists) - self.assertEqual('john', clone.first) - self.assertEqual('doe', clone.last) - self.assertFalse(hasattr(clone, 'created_at')) - 
self.assertFalse(hasattr(clone, 'updated_at')) + self.assertEqual("john", clone.first) + self.assertEqual("doe", clone.last) + self.assertFalse(hasattr(clone, "created_at")) + self.assertFalse(hasattr(clone, "updated_at")) # TODO: relation - clone.first = 'jane' + clone.first = "jane" - self.assertEqual('john', model.first) - self.assertEqual('jane', clone.first) + self.assertEqual("john", model.first) + self.assertEqual("jane", clone.first) def test_get_attribute_raise_attribute_error(self): model = OrmModelStub() try: relation = model.incorrect_relation - self.fail('AttributeError not raised') + self.fail("AttributeError not raised") except AttributeError: pass @@ -845,14 +918,14 @@ def test_increment(self): model = OrmModelStub() model.set_exists(True) model.id = 1 - model.sync_original_attribute('id') + model.sync_original_attribute("id") model.foo = 2 - model_mock.should_receive('new_query').and_return(query) - query.should_receive('where').and_return(query) - query.should_receive('increment') + model_mock.should_receive("new_query").and_return(query) + query.should_receive("where").and_return(query) + query.should_receive("increment") - model.public_increment('foo') + model.public_increment("foo") self.assertEqual(3, model.foo) self.assertFalse(model.is_dirty()) @@ -863,31 +936,33 @@ def test_increment(self): def test_timestamps_are_not_update_with_timestamps_false_save_option(self): query = flexmock(Builder) - query.should_receive('where').once().with_args('id', 1) - query.should_receive('update').once().with_args({'name': 'john'}) + query.should_receive("where").once().with_args("id", 1) + query.should_receive("update").once().with_args({"name": "john"}) model = OrmModelStub() - model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None))) + model.new_query = mock.MagicMock( + return_value=Builder(QueryBuilder(None, None, None)) + ) model.id = 1 model.sync_original() - model.name = 'john' + model.name = "john" model.set_exists(True) 
- self.assertTrue(model.save({'timestamps': False})) - self.assertFalse(hasattr(model, 'updated_at')) + self.assertTrue(model.save({"timestamps": False})) + self.assertFalse(hasattr(model, "updated_at")) model.new_query.assert_called_once_with() def test_casts(self): model = OrmModelCastingStub() - model.first = '3' - model.second = '4.0' + model.first = "3" + model.second = "4.0" model.third = 2.5 model.fourth = 1 model.fifth = 0 - model.sixth = {'foo': 'bar'} - model.seventh = ['foo', 'bar'] - model.eighth = {'foo': 'bar'} + model.sixth = {"foo": "bar"} + model.seventh = ["foo", "bar"] + model.eighth = {"foo": "bar"} self.assertIsInstance(model.first, int) self.assertIsInstance(model.second, float) @@ -899,25 +974,25 @@ def test_casts(self): self.assertIsInstance(model.eighth, dict) self.assertTrue(model.fourth) self.assertFalse(model.fifth) - self.assertEqual({'foo': 'bar'}, model.sixth) - self.assertEqual({'foo': 'bar'}, model.eighth) - self.assertEqual(['foo', 'bar'], model.seventh) + self.assertEqual({"foo": "bar"}, model.sixth) + self.assertEqual({"foo": "bar"}, model.eighth) + self.assertEqual(["foo", "bar"], model.seventh) d = model.to_dict() - self.assertIsInstance(d['first'], int) - self.assertIsInstance(d['second'], float) - self.assertIsInstance(d['third'], basestring) - self.assertIsInstance(d['fourth'], bool) - self.assertIsInstance(d['fifth'], bool) - self.assertIsInstance(d['sixth'], dict) - self.assertIsInstance(d['seventh'], list) - self.assertIsInstance(d['eighth'], dict) - self.assertTrue(d['fourth']) - self.assertFalse(d['fifth']) - self.assertEqual({'foo': 'bar'}, d['sixth']) - self.assertEqual({'foo': 'bar'}, d['eighth']) - self.assertEqual(['foo', 'bar'], d['seventh']) + self.assertIsInstance(d["first"], int) + self.assertIsInstance(d["second"], float) + self.assertIsInstance(d["third"], basestring) + self.assertIsInstance(d["fourth"], bool) + self.assertIsInstance(d["fifth"], bool) + self.assertIsInstance(d["sixth"], dict) + 
self.assertIsInstance(d["seventh"], list) + self.assertIsInstance(d["eighth"], dict) + self.assertTrue(d["fourth"]) + self.assertFalse(d["fifth"]) + self.assertEqual({"foo": "bar"}, d["sixth"]) + self.assertEqual({"foo": "bar"}, d["eighth"]) + self.assertEqual(["foo", "bar"], d["seventh"]) def test_casts_preserve_null(self): model = OrmModelCastingStub() @@ -941,57 +1016,57 @@ def test_casts_preserve_null(self): d = model.to_dict() - self.assertIsNone(d['first']) - self.assertIsNone(d['second']) - self.assertIsNone(d['third']) - self.assertIsNone(d['fourth']) - self.assertIsNone(d['fifth']) - self.assertIsNone(d['sixth']) - self.assertIsNone(d['seventh']) - self.assertIsNone(d['eighth']) + self.assertIsNone(d["first"]) + self.assertIsNone(d["second"]) + self.assertIsNone(d["third"]) + self.assertIsNone(d["fourth"]) + self.assertIsNone(d["fifth"]) + self.assertIsNone(d["sixth"]) + self.assertIsNone(d["seventh"]) + self.assertIsNone(d["eighth"]) def test_get_foreign_key(self): model = OrmModelStub() - model.set_table('stub') + model.set_table("stub") - self.assertEqual('stub_id', model.get_foreign_key()) + self.assertEqual("stub_id", model.get_foreign_key()) def test_default_values(self): model = OrmModelDefaultAttributes() - self.assertEqual('bar', model.foo) + self.assertEqual("bar", model.foo) def test_get_morph_name(self): model = OrmModelStub() - self.assertEqual('stub', model.get_morph_name()) + self.assertEqual("stub", model.get_morph_name()) class OrmModelStub(Model): - __table__ = 'stub' + __table__ = "stub" __guarded__ = [] @accessor def list_items(self): - return json.loads(self.get_raw_attribute('list_items')) + return json.loads(self.get_raw_attribute("list_items")) @list_items.mutator def set_list_items(self, value): - self.set_raw_attribute('list_items', json.dumps(value)) + self.set_raw_attribute("list_items", json.dumps(value)) @mutator def password(self, value): - self.set_raw_attribute('password_hash', hashlib.md5(value.encode()).hexdigest()) + 
self.set_raw_attribute("password_hash", hashlib.md5(value.encode()).hexdigest()) @password.accessor def get_password(self): - return '******' + return "******" @accessor def appendable(self): - return 'appended' + return "appended" def public_increment(self, column, amount=1): return self._increment(column, amount) @@ -1001,24 +1076,22 @@ def get_dates(self): class OrmModelHydrateRawStub(Model): - @classmethod def hydrate(cls, items, connection=None): - return 'hydrated' + return "hydrated" class OrmModelWithStub(Model): - def new_query(self): mock = flexmock(Builder(None)) - mock.should_receive('with_').once().with_args('foo', 'bar').and_return('foo') + mock.should_receive("with_").once().with_args("foo", "bar").and_return("foo") return mock class OrmModelSaveStub(Model): - __table__ = 'save_stub' + __table__ = "save_stub" __guarded__ = [] @@ -1033,42 +1106,42 @@ def get_saved(self): class OrmModelFindStub(Model): - def new_query(self): - flexmock(Builder).should_receive('find').once().with_args(1, ['*']).and_return('foo') + flexmock(Builder).should_receive("find").once().with_args(1, ["*"]).and_return( + "foo" + ) return Builder(None) class OrmModelFindWithWriteConnectionStub(Model): - def new_query(self): mock = flexmock(Builder) mock_query = flexmock(QueryBuilder) - mock_query.should_receive('use_write_connection').once().and_return(flexmock) - mock.should_receive('find').once().with_args(1).and_return('foo') + mock_query.should_receive("use_write_connection").once().and_return(flexmock) + mock.should_receive("find").once().with_args(1).and_return("foo") return Builder(QueryBuilder(None, None, None)) class OrmModelFindManyStub(Model): - def new_query(self): mock = flexmock(Builder) - mock.should_receive('find').once().with_args([1, 2], ['*']).and_return('foo') + mock.should_receive("find").once().with_args([1, 2], ["*"]).and_return("foo") return Builder(QueryBuilder(None, None, None)) class OrmModelDestroyStub(Model): - def new_query(self): mock = 
flexmock(Builder) model = flexmock() mock_query = flexmock(QueryBuilder) - mock_query.should_receive('where_in').once().with_args('id', [1, 2, 3]).and_return(flexmock) - mock.should_receive('get').once().and_return([model]) - model.should_receive('delete').once() + mock_query.should_receive("where_in").once().with_args( + "id", [1, 2, 3] + ).and_return(flexmock) + mock.should_receive("get").once().and_return([model]) + model.should_receive("delete").once() return Builder(QueryBuilder(None, None, None)) @@ -1081,28 +1154,27 @@ class OrmModelNoTableStub(Model): class OrmModelCastingStub(Model): __casts__ = { - 'first': 'int', - 'second': 'float', - 'third': 'str', - 'fourth': 'bool', - 'fifth': 'boolean', - 'sixth': 'dict', - 'seventh': 'list', - 'eighth': 'json' + "first": "int", + "second": "float", + "third": "str", + "fourth": "bool", + "fifth": "boolean", + "sixth": "dict", + "seventh": "list", + "eighth": "json", } + class OrmModelCreatedAt(Model): - __timestamps__ = ['created_at'] + __timestamps__ = ["created_at"] class OrmModelUpdatedAt(Model): - __timestamps__ = ['updated_at'] + __timestamps__ = ["updated_at"] class OrmModelDefaultAttributes(Model): - __attributes__ = { - 'foo': 'bar' - } + __attributes__ = {"foo": "bar"} diff --git a/tests/orm/test_model_global_scopes.py b/tests/orm/test_model_global_scopes.py index 41955706..30fd7f64 100644 --- a/tests/orm/test_model_global_scopes.py +++ b/tests/orm/test_model_global_scopes.py @@ -8,7 +8,6 @@ class ModelGlobalScopesTestCase(OratorTestCase): - @classmethod def setUpClass(cls): Model.set_connection_resolver(DatabaseConnectionResolver()) @@ -21,10 +20,7 @@ def test_global_scope_is_applied(self): model = GlobalScopesModel() query = model.new_query() - self.assertEqual( - 'SELECT * FROM "table" WHERE "active" = ?', - query.to_sql() - ) + self.assertEqual('SELECT * FROM "table" WHERE "active" = ?', query.to_sql()) self.assertEqual([1], query.get_bindings()) @@ -32,10 +28,7 @@ def 
test_global_scope_can_be_removed(self): model = GlobalScopesModel() query = model.new_query().without_global_scope(ActiveScope) - self.assertEqual( - 'SELECT * FROM "table"', - query.to_sql() - ) + self.assertEqual('SELECT * FROM "table"', query.to_sql()) self.assertEqual([], query.get_bindings()) @@ -45,19 +38,16 @@ def test_callable_global_scope_is_applied(self): self.assertEqual( 'SELECT * FROM "table" WHERE "active" = ? ORDER BY "name" ASC', - query.to_sql() + query.to_sql(), ) self.assertEqual([1], query.get_bindings()) def test_callable_global_scope_can_be_removed(self): model = CallableGlobalScopesModel() - query = model.new_query().without_global_scope('active_scope') + query = model.new_query().without_global_scope("active_scope") - self.assertEqual( - 'SELECT * FROM "table" ORDER BY "name" ASC', - query.to_sql() - ) + self.assertEqual('SELECT * FROM "table" ORDER BY "name" ASC', query.to_sql()) self.assertEqual([], query.get_bindings()) @@ -67,80 +57,76 @@ def test_global_scope_can_be_removed_after_query_is_executed(self): self.assertEqual( 'SELECT * FROM "table" WHERE "active" = ? 
ORDER BY "name" ASC', - query.to_sql() + query.to_sql(), ) self.assertEqual([1], query.get_bindings()) - query.without_global_scope('active_scope') + query.without_global_scope("active_scope") - self.assertEqual( - 'SELECT * FROM "table" ORDER BY "name" ASC', - query.to_sql() - ) + self.assertEqual('SELECT * FROM "table" ORDER BY "name" ASC', query.to_sql()) self.assertEqual([], query.get_bindings()) def test_all_global_scopes_can_be_removed(self): model = CallableGlobalScopesModel() query = model.new_query().without_global_scopes() - self.assertEqual( - 'SELECT * FROM "table"', - query.to_sql() - ) + self.assertEqual('SELECT * FROM "table"', query.to_sql()) self.assertEqual([], query.get_bindings()) query = CallableGlobalScopesModel.without_global_scopes() - self.assertEqual( - 'SELECT * FROM "table"', - query.to_sql() - ) + self.assertEqual('SELECT * FROM "table"', query.to_sql()) self.assertEqual([], query.get_bindings()) def test_global_scopes_with_or_where_conditions_are_nested(self): model = CallableGlobalScopesModelWithOr() - query = model.new_query().where('col1', 'val1').or_where('col2', 'val2') + query = model.new_query().where("col1", "val1").or_where("col2", "val2") self.assertEqual( 'SELECT "email", "password" FROM "table" ' 'WHERE ("col1" = ? OR "col2" = ?) AND ("email" = ? OR "email" = ?) ' 'AND ("active" = ?) 
ORDER BY "name" ASC', - query.to_sql() + query.to_sql(), ) self.assertEqual( - ['val1', 'val2', 'john@doe.com', 'someone@else.com', True], - query.get_bindings() + ["val1", "val2", "john@doe.com", "someone@else.com", True], + query.get_bindings(), ) class CallableGlobalScopesModel(Model): - __table__ = 'table' + __table__ = "table" @classmethod def _boot(cls): - cls.add_global_scope('active_scope', lambda query: query.where('active', 1)) + cls.add_global_scope("active_scope", lambda query: query.where("active", 1)) - cls.add_global_scope(lambda query: query.order_by('name')) + cls.add_global_scope(lambda query: query.order_by("name")) super(CallableGlobalScopesModel, cls)._boot() class CallableGlobalScopesModelWithOr(CallableGlobalScopesModel): - __table__ = 'table' + __table__ = "table" @classmethod def _boot(cls): - cls.add_global_scope('or_scope', lambda q: q.where('email', 'john@doe.com').or_where('email', 'someone@else.com')) + cls.add_global_scope( + "or_scope", + lambda q: q.where("email", "john@doe.com").or_where( + "email", "someone@else.com" + ), + ) - cls.add_global_scope(lambda query: query.select('email', 'password')) + cls.add_global_scope(lambda query: query.select("email", "password")) super(CallableGlobalScopesModelWithOr, cls)._boot() class GlobalScopesModel(Model): - __table__ = 'table' + __table__ = "table" @classmethod def _boot(cls): @@ -150,9 +136,8 @@ def _boot(cls): class ActiveScope(Scope): - def apply(self, builder, model): - return builder.where('active', 1) + return builder.where("active", 1) class DatabaseConnectionResolver(object): @@ -163,12 +148,14 @@ def connection(self, name=None): if self._connection: return self._connection - self._connection = SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'})) + self._connection = SQLiteConnection( + SQLiteConnector().connect({"database": ":memory:"}) + ) return self._connection def get_default_connection(self): - return 'default' + return "default" def 
set_default_connection(self, name): pass diff --git a/tests/pagination/__init__.py b/tests/pagination/__init__.py index 633f8661..40a96afc 100644 --- a/tests/pagination/__init__.py +++ b/tests/pagination/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/pagination/test_length_ware_paginator.py b/tests/pagination/test_length_ware_paginator.py index bf961e4a..d06a45d5 100644 --- a/tests/pagination/test_length_ware_paginator.py +++ b/tests/pagination/test_length_ware_paginator.py @@ -5,26 +5,25 @@ class LengthAwarePaginatorTestCase(OratorTestCase): - def test_returns_relevant_context(self): - p = LengthAwarePaginator(['item3', 'item4'], 4, 2, 2) + p = LengthAwarePaginator(["item3", "item4"], 4, 2, 2) self.assertEqual(2, p.current_page) self.assertEqual(2, p.last_page) self.assertEqual(4, p.total) self.assertTrue(p.has_pages()) self.assertFalse(p.has_more_pages()) - self.assertEqual(['item3', 'item4'], p.items) + self.assertEqual(["item3", "item4"], p.items) self.assertEqual(2, p.per_page) self.assertIsNone(p.next_page) self.assertEqual(1, p.previous_page) self.assertEqual(3, p.first_item) self.assertEqual(4, p.last_item) - self.assertEqual('item4', p[1]) + self.assertEqual("item4", p[1]) def test_integer_division_for_last_page(self): - p = LengthAwarePaginator(['item3', 'item4'], 5, 2, 2) + p = LengthAwarePaginator(["item3", "item4"], 5, 2, 2) self.assertEqual(2, p.current_page) self.assertEqual(3, p.last_page) diff --git a/tests/pagination/test_paginator.py b/tests/pagination/test_paginator.py index 86662fbd..a38a6a90 100644 --- a/tests/pagination/test_paginator.py +++ b/tests/pagination/test_paginator.py @@ -5,21 +5,20 @@ class PaginatorTestCase(OratorTestCase): - def test_returns_relevant_context(self): - p = Paginator(['item3', 'item4', 'item5'], 2, 2) + p = Paginator(["item3", "item4", "item5"], 2, 2) self.assertEqual(2, p.current_page) self.assertTrue(p.has_pages()) self.assertTrue(p.has_more_pages()) - self.assertEqual(['item3', 'item4'], 
p.items) + self.assertEqual(["item3", "item4"], p.items) self.assertEqual(2, p.per_page) self.assertEqual(3, p.next_page) self.assertEqual(1, p.previous_page) self.assertEqual(3, p.first_item) self.assertEqual(4, p.last_item) - self.assertEqual('item4', p[1]) + self.assertEqual("item4", p[1]) def test_current_page_resolver(self): def current_page_resolver(): diff --git a/tests/query/__init__.py b/tests/query/__init__.py index 633f8661..40a96afc 100644 --- a/tests/query/__init__.py +++ b/tests/query/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/query/test_query_builder.py b/tests/query/test_query_builder.py index 6b07be2c..26c1c5e3 100644 --- a/tests/query/test_query_builder.py +++ b/tests/query/test_query_builder.py @@ -12,7 +12,7 @@ QueryGrammar, PostgresQueryGrammar, SQLiteQueryGrammar, - MySQLQueryGrammar + MySQLQueryGrammar, ) from orator.query.builder import QueryBuilder from orator.query.expression import QueryExpression @@ -21,1208 +21,1236 @@ class QueryBuilderTestCase(OratorTestCase): - def test_basic_select(self): builder = self.get_builder() - builder.select('*').from_('users') + builder.select("*").from_("users") self.assertEqual('SELECT * FROM "users"', builder.to_sql()) def test_basic_select_use_write_connection(self): builder = self.get_builder() - builder.use_write_connection().select('*').from_('users').get() + builder.use_write_connection().select("*").from_("users").get() builder.get_connection().select.assert_called_once_with( - 'SELECT * FROM "users"', - [], - False + 'SELECT * FROM "users"', [], False ) builder = self.get_builder() - builder.select('*').from_('users').get() + builder.select("*").from_("users").get() builder.get_connection().select.assert_called_once_with( - 'SELECT * FROM "users"', - [], - True + 'SELECT * FROM "users"', [], True ) def test_alias_wrapping_as_whole_constant(self): builder = self.get_builder() - builder.select('x.y as foo.bar').from_('baz') + builder.select("x.y as foo.bar").from_("baz") 
- self.assertEqual( - 'SELECT "x"."y" AS "foo.bar" FROM "baz"', - builder.to_sql() - ) + self.assertEqual('SELECT "x"."y" AS "foo.bar" FROM "baz"', builder.to_sql()) def test_adding_selects(self): builder = self.get_builder() - builder.select('foo').add_select('bar').add_select('baz', 'boom').from_('users') + builder.select("foo").add_select("bar").add_select("baz", "boom").from_("users") self.assertEqual( - 'SELECT "foo", "bar", "baz", "boom" FROM "users"', - builder.to_sql() + 'SELECT "foo", "bar", "baz", "boom" FROM "users"', builder.to_sql() ) def test_basic_select_with_prefix(self): builder = self.get_builder() - builder.get_grammar().set_table_prefix('prefix_') - builder.select('*').from_('users') + builder.get_grammar().set_table_prefix("prefix_") + builder.select("*").from_("users") - self.assertEqual( - 'SELECT * FROM "prefix_users"', - builder.to_sql() - ) + self.assertEqual('SELECT * FROM "prefix_users"', builder.to_sql()) def test_basic_select_distinct(self): builder = self.get_builder() - builder.distinct().select('foo', 'bar').from_('users') + builder.distinct().select("foo", "bar").from_("users") - self.assertEqual( - 'SELECT DISTINCT "foo", "bar" FROM "users"', - builder.to_sql() - ) + self.assertEqual('SELECT DISTINCT "foo", "bar" FROM "users"', builder.to_sql()) def test_basic_alias(self): builder = self.get_builder() - builder.select('foo as bar').from_('users') + builder.select("foo as bar").from_("users") - self.assertEqual( - 'SELECT "foo" AS "bar" FROM "users"', - builder.to_sql() - ) + self.assertEqual('SELECT "foo" AS "bar" FROM "users"', builder.to_sql()) def test_mysql_wrapping_protects_qutotation_marks(self): builder = self.get_mysql_builder() - builder.select('*').from_('some`table') + builder.select("*").from_("some`table") - self.assertEqual( - 'SELECT * FROM `some``table`', - builder.to_sql() - ) + self.assertEqual("SELECT * FROM `some``table`", builder.to_sql()) def test_where_day_mysql(self): builder = self.get_mysql_builder() - 
builder.select('*').from_('users').where_day('created_at', '=', 1) + builder.select("*").from_("users").where_day("created_at", "=", 1) self.assertEqual( - 'SELECT * FROM `users` WHERE DAY(`created_at`) = %s', - builder.to_sql() + "SELECT * FROM `users` WHERE DAY(`created_at`) = %s", builder.to_sql() ) self.assertEqual([1], builder.get_bindings()) def test_where_month_mysql(self): builder = self.get_mysql_builder() - builder.select('*').from_('users').where_month('created_at', '=', 5) + builder.select("*").from_("users").where_month("created_at", "=", 5) self.assertEqual( - 'SELECT * FROM `users` WHERE MONTH(`created_at`) = %s', - builder.to_sql() + "SELECT * FROM `users` WHERE MONTH(`created_at`) = %s", builder.to_sql() ) self.assertEqual([5], builder.get_bindings()) def test_where_year_mysql(self): builder = self.get_mysql_builder() - builder.select('*').from_('users').where_year('created_at', '=', 2014) + builder.select("*").from_("users").where_year("created_at", "=", 2014) self.assertEqual( - 'SELECT * FROM `users` WHERE YEAR(`created_at`) = %s', - builder.to_sql() + "SELECT * FROM `users` WHERE YEAR(`created_at`) = %s", builder.to_sql() ) self.assertEqual([2014], builder.get_bindings()) def test_where_day_postgres(self): builder = self.get_postgres_builder() - builder.select('*').from_('users').where_day('created_at', '=', 1) + builder.select("*").from_("users").where_day("created_at", "=", 1) self.assertEqual( - 'SELECT * FROM "users" WHERE DAY("created_at") = %s', - builder.to_sql() + 'SELECT * FROM "users" WHERE DAY("created_at") = %s', builder.to_sql() ) self.assertEqual([1], builder.get_bindings()) def test_where_month_postgres(self): builder = self.get_postgres_builder() - builder.select('*').from_('users').where_month('created_at', '=', 5) + builder.select("*").from_("users").where_month("created_at", "=", 5) self.assertEqual( - 'SELECT * FROM "users" WHERE MONTH("created_at") = %s', - builder.to_sql() + 'SELECT * FROM "users" WHERE MONTH("created_at") 
= %s', builder.to_sql() ) self.assertEqual([5], builder.get_bindings()) def test_where_year_postgres(self): builder = self.get_postgres_builder() - builder.select('*').from_('users').where_year('created_at', '=', 2014) + builder.select("*").from_("users").where_year("created_at", "=", 2014) self.assertEqual( - 'SELECT * FROM "users" WHERE YEAR("created_at") = %s', - builder.to_sql() + 'SELECT * FROM "users" WHERE YEAR("created_at") = %s', builder.to_sql() ) self.assertEqual([2014], builder.get_bindings()) def test_where_day_sqlite(self): builder = self.get_sqlite_builder() - builder.select('*').from_('users').where_day('created_at', '=', 1) + builder.select("*").from_("users").where_day("created_at", "=", 1) self.assertEqual( 'SELECT * FROM "users" WHERE strftime(\'%d\', "created_at") = ?', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1], builder.get_bindings()) def test_where_month_sqlite(self): builder = self.get_sqlite_builder() - builder.select('*').from_('users').where_month('created_at', '=', 5) + builder.select("*").from_("users").where_month("created_at", "=", 5) self.assertEqual( 'SELECT * FROM "users" WHERE strftime(\'%m\', "created_at") = ?', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([5], builder.get_bindings()) def test_where_year_sqlite(self): builder = self.get_sqlite_builder() - builder.select('*').from_('users').where_year('created_at', '=', 2014) + builder.select("*").from_("users").where_year("created_at", "=", 2014) self.assertEqual( 'SELECT * FROM "users" WHERE strftime(\'%Y\', "created_at") = ?', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([2014], builder.get_bindings()) def test_where_between(self): builder = self.get_builder() - builder.select('*').from_('users').where_between('id', [1, 2]) + builder.select("*").from_("users").where_between("id", [1, 2]) self.assertEqual( - 'SELECT * FROM "users" WHERE "id" BETWEEN ? AND ?', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" BETWEEN ? 
AND ?', builder.to_sql() ) self.assertEqual([1, 2], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').where_not_between('id', [1, 2]) + builder.select("*").from_("users").where_not_between("id", [1, 2]) self.assertEqual( - 'SELECT * FROM "users" WHERE "id" NOT BETWEEN ? AND ?', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" NOT BETWEEN ? AND ?', builder.to_sql() ) self.assertEqual([1, 2], builder.get_bindings()) def test_basic_or_where(self): builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1).or_where('email', '=', 'foo') + builder.select("*").from_("users").where("id", "=", 1).or_where( + "email", "=", "foo" + ) self.assertEqual( - 'SELECT * FROM "users" WHERE "id" = ? OR "email" = ?', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" = ? OR "email" = ?', builder.to_sql() ) - self.assertEqual([1, 'foo'], builder.get_bindings()) + self.assertEqual([1, "foo"], builder.get_bindings()) def test_raw_wheres(self): builder = self.get_builder() - builder.select('*').from_('users').where_raw('id = ? or email = ?', [1, 'foo']) + builder.select("*").from_("users").where_raw("id = ? or email = ?", [1, "foo"]) self.assertEqual( - 'SELECT * FROM "users" WHERE id = ? OR email = ?', - builder.to_sql() + 'SELECT * FROM "users" WHERE id = ? 
OR email = ?', builder.to_sql() ) - self.assertEqual([1, 'foo'], builder.get_bindings()) + self.assertEqual([1, "foo"], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').where_raw('id = ?', [1]) - self.assertEqual( - 'SELECT * FROM "users" WHERE id = ?', - builder.to_sql() - ) + builder.select("*").from_("users").where_raw("id = ?", [1]) + self.assertEqual('SELECT * FROM "users" WHERE id = ?', builder.to_sql()) self.assertEqual([1], builder.get_bindings()) def test_raw_or_wheres(self): builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1).or_where_raw('email = ?', ['foo']) + builder.select("*").from_("users").where("id", "=", 1).or_where_raw( + "email = ?", ["foo"] + ) self.assertEqual( - 'SELECT * FROM "users" WHERE "id" = ? OR email = ?', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" = ? OR email = ?', builder.to_sql() ) - self.assertEqual([1, 'foo'], builder.get_bindings()) + self.assertEqual([1, "foo"], builder.get_bindings()) def test_basic_where_ins(self): builder = self.get_builder() - builder.select('*').from_('users').where_in('id', [1, 2, 3]) + builder.select("*").from_("users").where_in("id", [1, 2, 3]) self.assertEqual( - 'SELECT * FROM "users" WHERE "id" IN (?, ?, ?)', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" IN (?, ?, ?)', builder.to_sql() ) self.assertEqual([1, 2, 3], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1).or_where_in('id', [1, 2, 3]) + builder.select("*").from_("users").where("id", "=", 1).or_where_in( + "id", [1, 2, 3] + ) self.assertEqual( 'SELECT * FROM "users" WHERE "id" = ? 
OR "id" IN (?, ?, ?)', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1, 1, 2, 3], builder.get_bindings()) def test_basic_where_not_ins(self): builder = self.get_builder() - builder.select('*').from_('users').where_not_in('id', [1, 2, 3]) + builder.select("*").from_("users").where_not_in("id", [1, 2, 3]) self.assertEqual( - 'SELECT * FROM "users" WHERE "id" NOT IN (?, ?, ?)', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" NOT IN (?, ?, ?)', builder.to_sql() ) self.assertEqual([1, 2, 3], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1).or_where_not_in('id', [1, 2, 3]) + builder.select("*").from_("users").where("id", "=", 1).or_where_not_in( + "id", [1, 2, 3] + ) self.assertEqual( 'SELECT * FROM "users" WHERE "id" = ? OR "id" NOT IN (?, ?, ?)', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1, 1, 2, 3], builder.get_bindings()) def test_empty_where_ins(self): builder = self.get_builder() - builder.select('*').from_('users').where_in('id', []) - self.assertEqual( - 'SELECT * FROM "users" WHERE 0 = 1', - builder.to_sql() - ) + builder.select("*").from_("users").where_in("id", []) + self.assertEqual('SELECT * FROM "users" WHERE 0 = 1', builder.to_sql()) self.assertEqual([], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1).or_where_in('id', []) + builder.select("*").from_("users").where("id", "=", 1).or_where_in("id", []) self.assertEqual( - 'SELECT * FROM "users" WHERE "id" = ? OR 0 = 1', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" = ? 
OR 0 = 1', builder.to_sql() ) self.assertEqual([1], builder.get_bindings()) def test_empty_where_not_ins(self): builder = self.get_builder() - builder.select('*').from_('users').where_not_in('id', []) - self.assertEqual( - 'SELECT * FROM "users" WHERE 1 = 1', - builder.to_sql() - ) + builder.select("*").from_("users").where_not_in("id", []) + self.assertEqual('SELECT * FROM "users" WHERE 1 = 1', builder.to_sql()) self.assertEqual([], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1).or_where_not_in('id', []) + builder.select("*").from_("users").where("id", "=", 1).or_where_not_in("id", []) self.assertEqual( - 'SELECT * FROM "users" WHERE "id" = ? OR 1 = 1', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" = ? OR 1 = 1', builder.to_sql() ) self.assertEqual([1], builder.get_bindings()) def test_where_in_accepts_collections(self): builder = self.get_builder() - builder.select('*').from_('users').where_in('id', Collection([1, 2, 3])) + builder.select("*").from_("users").where_in("id", Collection([1, 2, 3])) self.assertEqual( - 'SELECT * FROM "users" WHERE "id" IN (?, ?, ?)', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" IN (?, ?, ?)', builder.to_sql() ) self.assertEqual([1, 2, 3], builder.get_bindings()) def test_unions(self): builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1) - builder.union(self.get_builder().select('*').from_('users').where('id', '=', 2)) + builder.select("*").from_("users").where("id", "=", 1) + builder.union(self.get_builder().select("*").from_("users").where("id", "=", 2)) self.assertEqual( 'SELECT * FROM "users" WHERE "id" = ? 
UNION SELECT * FROM "users" WHERE "id" = ?', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1, 2], builder.get_bindings()) builder = self.get_mysql_builder() - builder.select('*').from_('users').where('id', '=', 1) - builder.union(self.get_mysql_builder().select('*').from_('users').where('id', '=', 2)) + builder.select("*").from_("users").where("id", "=", 1) + builder.union( + self.get_mysql_builder().select("*").from_("users").where("id", "=", 2) + ) self.assertEqual( - '(SELECT * FROM `users` WHERE `id` = %s) UNION (SELECT * FROM `users` WHERE `id` = %s)', - builder.to_sql() + "(SELECT * FROM `users` WHERE `id` = %s) UNION (SELECT * FROM `users` WHERE `id` = %s)", + builder.to_sql(), ) self.assertEqual([1, 2], builder.get_bindings()) def test_union_alls(self): builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1) - builder.union_all(self.get_builder().select('*').from_('users').where('id', '=', 2)) + builder.select("*").from_("users").where("id", "=", 1) + builder.union_all( + self.get_builder().select("*").from_("users").where("id", "=", 2) + ) self.assertEqual( 'SELECT * FROM "users" WHERE "id" = ? UNION ALL SELECT * FROM "users" WHERE "id" = ?', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1, 2], builder.get_bindings()) def test_multiple_unions(self): builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1) - builder.union(self.get_builder().select('*').from_('users').where('id', '=', 2)) - builder.union(self.get_builder().select('*').from_('users').where('id', '=', 3)) + builder.select("*").from_("users").where("id", "=", 1) + builder.union(self.get_builder().select("*").from_("users").where("id", "=", 2)) + builder.union(self.get_builder().select("*").from_("users").where("id", "=", 3)) self.assertEqual( 'SELECT * FROM "users" WHERE "id" = ? ' 'UNION SELECT * FROM "users" WHERE "id" = ? 
' 'UNION SELECT * FROM "users" WHERE "id" = ?', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1, 2, 3], builder.get_bindings()) def test_multiple_union_alls(self): builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1) - builder.union_all(self.get_builder().select('*').from_('users').where('id', '=', 2)) - builder.union_all(self.get_builder().select('*').from_('users').where('id', '=', 3)) + builder.select("*").from_("users").where("id", "=", 1) + builder.union_all( + self.get_builder().select("*").from_("users").where("id", "=", 2) + ) + builder.union_all( + self.get_builder().select("*").from_("users").where("id", "=", 3) + ) self.assertEqual( 'SELECT * FROM "users" WHERE "id" = ? ' 'UNION ALL SELECT * FROM "users" WHERE "id" = ? ' 'UNION ALL SELECT * FROM "users" WHERE "id" = ?', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1, 2, 3], builder.get_bindings()) def test_union_order_bys(self): builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1) - builder.union(self.get_builder().select('*').from_('users').where('id', '=', 2)) - builder.order_by('id', 'desc') + builder.select("*").from_("users").where("id", "=", 1) + builder.union(self.get_builder().select("*").from_("users").where("id", "=", 2)) + builder.order_by("id", "desc") self.assertEqual( 'SELECT * FROM "users" WHERE "id" = ? ' 'UNION SELECT * FROM "users" WHERE "id" = ? ' 'ORDER BY "id" DESC', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1, 2], builder.get_bindings()) def test_union_limits_and_offsets(self): builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1) - builder.union(self.get_builder().select('*').from_('users').where('id', '=', 2)) + builder.select("*").from_("users").where("id", "=", 1) + builder.union(self.get_builder().select("*").from_("users").where("id", "=", 2)) builder.skip(5).take(10) self.assertEqual( 'SELECT * FROM "users" WHERE "id" = ? 
' 'UNION SELECT * FROM "users" WHERE "id" = ? ' - 'LIMIT 10 OFFSET 5', - builder.to_sql() + "LIMIT 10 OFFSET 5", + builder.to_sql(), ) self.assertEqual([1, 2], builder.get_bindings()) def test_mysql_union_order_bys(self): builder = self.get_mysql_builder() - builder.select('*').from_('users').where('id', '=', 1) - builder.union(self.get_mysql_builder().select('*').from_('users').where('id', '=', 2)) - builder.order_by('id', 'desc') + builder.select("*").from_("users").where("id", "=", 1) + builder.union( + self.get_mysql_builder().select("*").from_("users").where("id", "=", 2) + ) + builder.order_by("id", "desc") self.assertEqual( - '(SELECT * FROM `users` WHERE `id` = %s) ' - 'UNION (SELECT * FROM `users` WHERE `id` = %s) ' - 'ORDER BY `id` DESC', - builder.to_sql() + "(SELECT * FROM `users` WHERE `id` = %s) " + "UNION (SELECT * FROM `users` WHERE `id` = %s) " + "ORDER BY `id` DESC", + builder.to_sql(), ) self.assertEqual([1, 2], builder.get_bindings()) def test_mysql_union_limits_and_offsets(self): builder = self.get_mysql_builder() - builder.select('*').from_('users').where('id', '=', 1) - builder.union(self.get_mysql_builder().select('*').from_('users').where('id', '=', 2)) + builder.select("*").from_("users").where("id", "=", 1) + builder.union( + self.get_mysql_builder().select("*").from_("users").where("id", "=", 2) + ) builder.skip(5).take(10) self.assertEqual( - '(SELECT * FROM `users` WHERE `id` = %s) ' - 'UNION (SELECT * FROM `users` WHERE `id` = %s) ' - 'LIMIT 10 OFFSET 5', - builder.to_sql() + "(SELECT * FROM `users` WHERE `id` = %s) " + "UNION (SELECT * FROM `users` WHERE `id` = %s) " + "LIMIT 10 OFFSET 5", + builder.to_sql(), ) self.assertEqual([1, 2], builder.get_bindings()) def test_sub_select_where_in(self): builder = self.get_builder() - builder.select('*').from_('users').where_in( - 'id', - self.get_builder().select('id').from_('users').where('age', '>', 25).take(3) + builder.select("*").from_("users").where_in( + "id", + self.get_builder() + 
.select("id") + .from_("users") + .where("age", ">", 25) + .take(3), ) self.assertEqual( 'SELECT * FROM "users" WHERE "id" IN (SELECT "id" FROM "users" WHERE "age" > ? LIMIT 3)', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([25], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').where_not_in( - 'id', - self.get_builder().select('id').from_('users').where('age', '>', 25).take(3) + builder.select("*").from_("users").where_not_in( + "id", + self.get_builder() + .select("id") + .from_("users") + .where("age", ">", 25) + .take(3), ) self.assertEqual( 'SELECT * FROM "users" WHERE "id" NOT IN (SELECT "id" FROM "users" WHERE "age" > ? LIMIT 3)', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([25], builder.get_bindings()) def test_basic_where_null(self): builder = self.get_builder() - builder.select('*').from_('users').where_null('id') - self.assertEqual( - 'SELECT * FROM "users" WHERE "id" IS NULL', - builder.to_sql() - ) + builder.select("*").from_("users").where_null("id") + self.assertEqual('SELECT * FROM "users" WHERE "id" IS NULL', builder.to_sql()) self.assertEqual([], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1).or_where_null('id') + builder.select("*").from_("users").where("id", "=", 1).or_where_null("id") self.assertEqual( - 'SELECT * FROM "users" WHERE "id" = ? OR "id" IS NULL', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" = ? 
OR "id" IS NULL', builder.to_sql() ) self.assertEqual([1], builder.get_bindings()) def test_basic_where_not_null(self): builder = self.get_builder() - builder.select('*').from_('users').where_not_null('id') + builder.select("*").from_("users").where_not_null("id") self.assertEqual( - 'SELECT * FROM "users" WHERE "id" IS NOT NULL', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" IS NOT NULL', builder.to_sql() ) self.assertEqual([], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').where('id', '=', 1).or_where_not_null('id') + builder.select("*").from_("users").where("id", "=", 1).or_where_not_null("id") self.assertEqual( - 'SELECT * FROM "users" WHERE "id" = ? OR "id" IS NOT NULL', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" = ? OR "id" IS NOT NULL', builder.to_sql() ) self.assertEqual([1], builder.get_bindings()) def test_group_bys(self): builder = self.get_builder() - builder.select('*').from_('users').group_by('id', 'email') + builder.select("*").from_("users").group_by("id", "email") self.assertEqual( - 'SELECT * FROM "users" GROUP BY "id", "email"', - builder.to_sql() + 'SELECT * FROM "users" GROUP BY "id", "email"', builder.to_sql() ) def test_order_bys(self): builder = self.get_builder() - builder.select('*').from_('users').order_by('email').order_by('age', 'desc') + builder.select("*").from_("users").order_by("email").order_by("age", "desc") self.assertEqual( - 'SELECT * FROM "users" ORDER BY "email" ASC, "age" DESC', - builder.to_sql() + 'SELECT * FROM "users" ORDER BY "email" ASC, "age" DESC', builder.to_sql() ) builder = self.get_builder() - builder.select('*').from_('users').order_by('email').order_by_raw('"age" ? desc', ['foo']) + builder.select("*").from_("users").order_by("email").order_by_raw( + '"age" ? desc', ["foo"] + ) self.assertEqual( - 'SELECT * FROM "users" ORDER BY "email" ASC, "age" ? DESC', - builder.to_sql() + 'SELECT * FROM "users" ORDER BY "email" ASC, "age" ? 
DESC', builder.to_sql() ) def test_havings(self): builder = self.get_builder() - builder.select('*').from_('users').having('email', '>', 1) - self.assertEqual( - 'SELECT * FROM "users" HAVING "email" > ?', - builder.to_sql() - ) + builder.select("*").from_("users").having("email", ">", 1) + self.assertEqual('SELECT * FROM "users" HAVING "email" > ?', builder.to_sql()) self.assertEqual([1], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users')\ - .or_having('email', '=', 'foo@bar.com')\ - .or_having('email', '=', 'foo2@bar.com') + builder.select("*").from_("users").or_having( + "email", "=", "foo@bar.com" + ).or_having("email", "=", "foo2@bar.com") self.assertEqual( - 'SELECT * FROM "users" HAVING "email" = ? OR "email" = ?', - builder.to_sql() + 'SELECT * FROM "users" HAVING "email" = ? OR "email" = ?', builder.to_sql() ) - self.assertEqual(['foo@bar.com', 'foo2@bar.com'], builder.get_bindings()) + self.assertEqual(["foo@bar.com", "foo2@bar.com"], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users')\ - .group_by('email')\ - .having('email', '>', 1) + builder.select("*").from_("users").group_by("email").having("email", ">", 1) self.assertEqual( 'SELECT * FROM "users" GROUP BY "email" HAVING "email" > ?', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1], builder.get_bindings()) builder = self.get_builder() - builder.select('email as foo_mail').from_('users')\ - .having('foo_mail', '>', 1) + builder.select("email as foo_mail").from_("users").having("foo_mail", ">", 1) self.assertEqual( 'SELECT "email" AS "foo_mail" FROM "users" HAVING "foo_mail" > ?', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1], builder.get_bindings()) builder = self.get_builder() - builder.select('category', QueryExpression('count(*) as "total"'))\ - .from_('item')\ - .where('department', '=', 'popular')\ - .group_by('category')\ - .having('total', '>', QueryExpression('3')) + 
builder.select("category", QueryExpression('count(*) as "total"')).from_( + "item" + ).where("department", "=", "popular").group_by("category").having( + "total", ">", QueryExpression("3") + ) self.assertEqual( 'SELECT "category", count(*) as "total" ' 'FROM "item" ' 'WHERE "department" = ? ' 'GROUP BY "category" ' 'HAVING "total" > 3', - builder.to_sql() + builder.to_sql(), ) - self.assertEqual(['popular'], builder.get_bindings()) + self.assertEqual(["popular"], builder.get_bindings()) builder = self.get_builder() - builder.select('category', QueryExpression('count(*) as "total"'))\ - .from_('item')\ - .where('department', '=', 'popular')\ - .group_by('category')\ - .having('total', '>', 3) + builder.select("category", QueryExpression('count(*) as "total"')).from_( + "item" + ).where("department", "=", "popular").group_by("category").having( + "total", ">", 3 + ) self.assertEqual( 'SELECT "category", count(*) as "total" ' 'FROM "item" ' 'WHERE "department" = ? ' 'GROUP BY "category" ' 'HAVING "total" > ?', - builder.to_sql() + builder.to_sql(), ) - self.assertEqual(['popular', 3], builder.get_bindings()) + self.assertEqual(["popular", 3], builder.get_bindings()) def test_having_followed_by_select_get(self): builder = self.get_builder() - query = 'SELECT "category", count(*) as "total" ' \ - 'FROM "item" ' \ - 'WHERE "department" = ? ' \ - 'GROUP BY "category" ' \ - 'HAVING "total" > ?' - results = [{ - 'category': 'rock', - 'total': 5 - }] + query = ( + 'SELECT "category", count(*) as "total" ' + 'FROM "item" ' + 'WHERE "department" = ? ' + 'GROUP BY "category" ' + 'HAVING "total" > ?' 
+ ) + results = [{"category": "rock", "total": 5}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results: results) - result = builder.select('category', QueryExpression('count(*) as "total"'))\ - .from_('item')\ - .where('department', '=', 'popular')\ - .group_by('category')\ - .having('total', '>', 3)\ + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results: results + ) + result = ( + builder.select("category", QueryExpression('count(*) as "total"')) + .from_("item") + .where("department", "=", "popular") + .group_by("category") + .having("total", ">", 3) .get() + ) builder.get_connection().select.assert_called_once_with( - query, - ['popular', 3], - True - ) - builder.get_processor().process_select.assert_called_once_with( - builder, - results + query, ["popular", 3], True ) + builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertEqual(results, result) - self.assertEqual(['popular', 3], builder.get_bindings()) + self.assertEqual(["popular", 3], builder.get_bindings()) # Using raw value builder = self.get_builder() - query = 'SELECT "category", count(*) as "total" ' \ - 'FROM "item" ' \ - 'WHERE "department" = ? ' \ - 'GROUP BY "category" ' \ - 'HAVING "total" > 3' + query = ( + 'SELECT "category", count(*) as "total" ' + 'FROM "item" ' + 'WHERE "department" = ? 
' + 'GROUP BY "category" ' + 'HAVING "total" > 3' + ) builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results: results) - result = builder.select('category', QueryExpression('count(*) as "total"'))\ - .from_('item')\ - .where('department', '=', 'popular')\ - .group_by('category')\ - .having('total', '>', QueryExpression('3'))\ + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results: results + ) + result = ( + builder.select("category", QueryExpression('count(*) as "total"')) + .from_("item") + .where("department", "=", "popular") + .group_by("category") + .having("total", ">", QueryExpression("3")) .get() + ) builder.get_connection().select.assert_called_once_with( - query, - ['popular'], - True - ) - builder.get_processor().process_select.assert_called_once_with( - builder, - results + query, ["popular"], True ) + builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertEqual(results, result) - self.assertEqual(['popular'], builder.get_bindings()) + self.assertEqual(["popular"], builder.get_bindings()) def test_raw_havings(self): builder = self.get_builder() - builder.select('*').from_('users').having_raw('user_foo < user_bar') + builder.select("*").from_("users").having_raw("user_foo < user_bar") self.assertEqual( - 'SELECT * FROM "users" HAVING user_foo < user_bar', - builder.to_sql() + 'SELECT * FROM "users" HAVING user_foo < user_bar', builder.to_sql() ) self.assertEqual([], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').having('foo', '=', 1).or_having_raw('user_foo < user_bar') + builder.select("*").from_("users").having("foo", "=", 1).or_having_raw( + "user_foo < user_bar" + ) self.assertEqual( 'SELECT * FROM "users" HAVING "foo" = ? 
OR user_foo < user_bar', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1], builder.get_bindings()) def test_limits_and_offsets(self): builder = self.get_builder() - builder.select('*').from_('users').offset(5).limit(10) + builder.select("*").from_("users").offset(5).limit(10) self.assertEqual('SELECT * FROM "users" LIMIT 10 OFFSET 5', builder.to_sql()) builder = self.get_builder() - builder.select('*').from_('users').skip(5).take(10) + builder.select("*").from_("users").skip(5).take(10) self.assertEqual('SELECT * FROM "users" LIMIT 10 OFFSET 5', builder.to_sql()) builder = self.get_builder() - builder.select('*').from_('users').skip(-5).take(10) + builder.select("*").from_("users").skip(-5).take(10) self.assertEqual('SELECT * FROM "users" LIMIT 10 OFFSET 0', builder.to_sql()) builder = self.get_builder() - builder.select('*').from_('users').for_page(2, 15) + builder.select("*").from_("users").for_page(2, 15) self.assertEqual('SELECT * FROM "users" LIMIT 15 OFFSET 15', builder.to_sql()) builder = self.get_builder() - builder.select('*').from_('users').for_page(-2, 15) + builder.select("*").from_("users").for_page(-2, 15) self.assertEqual('SELECT * FROM "users" LIMIT 15 OFFSET 0', builder.to_sql()) def test_where_shortcut(self): builder = self.get_builder() - builder.select('*').from_('users').where('id', 1).or_where('name', 'foo') + builder.select("*").from_("users").where("id", 1).or_where("name", "foo") self.assertEqual( - 'SELECT * FROM "users" WHERE "id" = ? OR "name" = ?', - builder.to_sql() + 'SELECT * FROM "users" WHERE "id" = ? 
OR "name" = ?', builder.to_sql() ) - self.assertEqual([1, 'foo'], builder.get_bindings()) + self.assertEqual([1, "foo"], builder.get_bindings()) def test_multiple_wheres_in_list(self): builder = self.get_builder() - builder.select('*').from_('users').where([['name', '=', 'bar'], ['age', '=', 25]]) + builder.select("*").from_("users").where( + [["name", "=", "bar"], ["age", "=", 25]] + ) self.assertEqual( - 'SELECT * FROM "users" WHERE ("name" = ? AND "age" = ?)', - builder.to_sql() + 'SELECT * FROM "users" WHERE ("name" = ? AND "age" = ?)', builder.to_sql() ) - self.assertEqual(['bar', 25], builder.get_bindings()) + self.assertEqual(["bar", 25], builder.get_bindings()) def test_multiple_wheres_in_list_with_exception(self): builder = self.get_builder() try: - builder.select('*').from_('users').where([['name', 'bar'], ['age', '=', 25]]) - self.fail('Builder has not raised Argument Error for invalid no. of values in where list') + builder.select("*").from_("users").where( + [["name", "bar"], ["age", "=", 25]] + ) + self.fail( + "Builder has not raised Argument Error for invalid no. of values in where list" + ) except ArgumentError: self.assertTrue(True) try: - builder.select('*').from_('users').where(['name', 'bar']) - self.fail('Builder has not raised Argument Error for invalid datatype in where list') + builder.select("*").from_("users").where(["name", "bar"]) + self.fail( + "Builder has not raised Argument Error for invalid datatype in where list" + ) except ArgumentError: self.assertTrue(True) - def test_nested_wheres(self): builder = self.get_builder() - builder.select('*').from_('users').where('email', '=', 'foo').or_where( - builder.new_query().where('name', '=', 'bar').where('age', '=', 25) + builder.select("*").from_("users").where("email", "=", "foo").or_where( + builder.new_query().where("name", "=", "bar").where("age", "=", 25) ) self.assertEqual( 'SELECT * FROM "users" WHERE "email" = ? OR ("name" = ? 
AND "age" = ?)', - builder.to_sql() + builder.to_sql(), ) - self.assertEqual(['foo', 'bar', 25], builder.get_bindings()) + self.assertEqual(["foo", "bar", 25], builder.get_bindings()) def test_full_sub_selects(self): builder = self.get_builder() - builder.select('*').from_('users').where('email', '=', 'foo').or_where( - 'id', '=', builder.new_query().select(QueryExpression('max(id)')).from_('users').where('email', '=', 'bar') + builder.select("*").from_("users").where("email", "=", "foo").or_where( + "id", + "=", + builder.new_query() + .select(QueryExpression("max(id)")) + .from_("users") + .where("email", "=", "bar"), ) self.assertEqual( 'SELECT * FROM "users" WHERE "email" = ? OR "id" = (SELECT max(id) FROM "users" WHERE "email" = ?)', - builder.to_sql() + builder.to_sql(), ) - self.assertEqual(['foo', 'bar'], builder.get_bindings()) + self.assertEqual(["foo", "bar"], builder.get_bindings()) def test_where_exists(self): builder = self.get_builder() - builder.select('*').from_('orders').where_exists( - self.get_builder().select('*').from_('products').where('products.id', '=', QueryExpression('"orders"."id"')) + builder.select("*").from_("orders").where_exists( + self.get_builder() + .select("*") + .from_("products") + .where("products.id", "=", QueryExpression('"orders"."id"')) ) self.assertEqual( 'SELECT * FROM "orders" ' 'WHERE EXISTS (SELECT * FROM "products" WHERE "products"."id" = "orders"."id")', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('orders').where_not_exists( - self.get_builder().select('*').from_('products').where('products.id', '=', QueryExpression('"orders"."id"')) + builder.select("*").from_("orders").where_not_exists( + self.get_builder() + .select("*") + .from_("products") + .where("products.id", "=", QueryExpression('"orders"."id"')) ) self.assertEqual( 'SELECT * FROM "orders" ' 'WHERE NOT EXISTS (SELECT * FROM "products" WHERE "products"."id" = 
"orders"."id")', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('orders').where('id', '=', 1).or_where_exists( - self.get_builder().select('*').from_('products').where('products.id', '=', QueryExpression('"orders"."id"')) + builder.select("*").from_("orders").where("id", "=", 1).or_where_exists( + self.get_builder() + .select("*") + .from_("products") + .where("products.id", "=", QueryExpression('"orders"."id"')) ) self.assertEqual( 'SELECT * FROM "orders" WHERE "id" = ? ' 'OR EXISTS (SELECT * FROM "products" WHERE "products"."id" = "orders"."id")', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('orders').where('id', '=', 1).or_where_not_exists( - self.get_builder().select('*').from_('products').where('products.id', '=', QueryExpression('"orders"."id"')) + builder.select("*").from_("orders").where("id", "=", 1).or_where_not_exists( + self.get_builder() + .select("*") + .from_("products") + .where("products.id", "=", QueryExpression('"orders"."id"')) ) self.assertEqual( 'SELECT * FROM "orders" WHERE "id" = ? 
' 'OR NOT EXISTS (SELECT * FROM "products" WHERE "products"."id" = "orders"."id")', - builder.to_sql() + builder.to_sql(), ) self.assertEqual([1], builder.get_bindings()) def test_basic_joins(self): builder = self.get_builder() - builder.select('*').from_('users')\ - .join('contacts', 'users.id', '=', 'contacts.id')\ - .left_join('photos', 'users.id', '=', 'photos.user_id') + builder.select("*").from_("users").join( + "contacts", "users.id", "=", "contacts.id" + ).left_join("photos", "users.id", "=", "photos.user_id") self.assertEqual( 'SELECT * FROM "users" ' 'INNER JOIN "contacts" ON "users"."id" = "contacts"."id" ' 'LEFT JOIN "photos" ON "users"."id" = "photos"."user_id"', - builder.to_sql() + builder.to_sql(), ) builder = self.get_builder() - builder.select('*').from_('users')\ - .left_join_where('photos', 'users.id', '=', 3)\ - .join_where('photos', 'users.id', '=', 'foo') + builder.select("*").from_("users").left_join_where( + "photos", "users.id", "=", 3 + ).join_where("photos", "users.id", "=", "foo") self.assertEqual( 'SELECT * FROM "users" ' 'LEFT JOIN "photos" ON "users"."id" = ? 
' 'INNER JOIN "photos" ON "users"."id" = ?', - builder.to_sql() + builder.to_sql(), ) - self.assertEqual([3, 'foo'], builder.get_bindings()) + self.assertEqual([3, "foo"], builder.get_bindings()) def test_complex_joins(self): builder = self.get_builder() - builder.select('*').from_('users').join( - JoinClause('contacts') - .on('users.id', '=', 'contacts.id') - .or_on('users.name', '=', 'contacts.name') + builder.select("*").from_("users").join( + JoinClause("contacts") + .on("users.id", "=", "contacts.id") + .or_on("users.name", "=", "contacts.name") ) self.assertEqual( 'SELECT * FROM "users" ' 'INNER JOIN "contacts" ON "users"."id" = "contacts"."id" ' 'OR "users"."name" = "contacts"."name"', - builder.to_sql() + builder.to_sql(), ) builder = self.get_builder() - builder.select('*').from_('users').join( - JoinClause('contacts') - .where('users.id', '=', 'foo') - .or_where('users.name', '=', 'bar') + builder.select("*").from_("users").join( + JoinClause("contacts") + .where("users.id", "=", "foo") + .or_where("users.name", "=", "bar") ) self.assertEqual( 'SELECT * FROM "users" ' 'INNER JOIN "contacts" ON "users"."id" = ? ' 'OR "users"."name" = ?', - builder.to_sql() + builder.to_sql(), ) - self.assertEqual(['foo', 'bar'], builder.get_bindings()) + self.assertEqual(["foo", "bar"], builder.get_bindings()) self.assertEqual( 'SELECT * FROM "users" ' 'INNER JOIN "contacts" ON "users"."id" = ? ' 'OR "users"."name" = ?', - builder.to_sql() + builder.to_sql(), ) - self.assertEqual(['foo', 'bar'], builder.get_bindings()) + self.assertEqual(["foo", "bar"], builder.get_bindings()) builder = self.get_builder() - builder.select('*').from_('users').left_join( - JoinClause('contacts') - .where('users.id', '=', 'foo') - .or_where('users.name', '=', 'bar') + builder.select("*").from_("users").left_join( + JoinClause("contacts") + .where("users.id", "=", "foo") + .or_where("users.name", "=", "bar") ) self.assertEqual( 'SELECT * FROM "users" ' 'LEFT JOIN "contacts" ON "users"."id" = ? 
' 'OR "users"."name" = ?', - builder.to_sql() + builder.to_sql(), ) - self.assertEqual(['foo', 'bar'], builder.get_bindings()) + self.assertEqual(["foo", "bar"], builder.get_bindings()) self.assertEqual( 'SELECT * FROM "users" ' 'LEFT JOIN "contacts" ON "users"."id" = ? ' 'OR "users"."name" = ?', - builder.to_sql() + builder.to_sql(), ) - self.assertEqual(['foo', 'bar'], builder.get_bindings()) + self.assertEqual(["foo", "bar"], builder.get_bindings()) def test_join_where_null(self): builder = self.get_builder() - builder.select('*').from_('users').join( - JoinClause('contacts') - .on('users.id', '=', 'contacts.id') - .where_null('contacts.deleted_at') + builder.select("*").from_("users").join( + JoinClause("contacts") + .on("users.id", "=", "contacts.id") + .where_null("contacts.deleted_at") ) self.assertEqual( 'SELECT * FROM "users" ' 'INNER JOIN "contacts" ' 'ON "users"."id" = "contacts"."id" ' 'AND "contacts"."deleted_at" IS NULL', - builder.to_sql() + builder.to_sql(), ) builder = self.get_builder() - builder.select('*').from_('users').join( - JoinClause('contacts') - .on('users.id', '=', 'contacts.id') - .or_where_null('contacts.deleted_at') + builder.select("*").from_("users").join( + JoinClause("contacts") + .on("users.id", "=", "contacts.id") + .or_where_null("contacts.deleted_at") ) self.assertEqual( 'SELECT * FROM "users" ' 'INNER JOIN "contacts" ' 'ON "users"."id" = "contacts"."id" ' 'OR "contacts"."deleted_at" IS NULL', - builder.to_sql() + builder.to_sql(), ) def test_join_where_not_null(self): builder = self.get_builder() - builder.select('*').from_('users').join( - JoinClause('contacts') - .on('users.id', '=', 'contacts.id') - .where_not_null('contacts.deleted_at') + builder.select("*").from_("users").join( + JoinClause("contacts") + .on("users.id", "=", "contacts.id") + .where_not_null("contacts.deleted_at") ) self.assertEqual( 'SELECT * FROM "users" ' 'INNER JOIN "contacts" ' 'ON "users"."id" = "contacts"."id" ' 'AND "contacts"."deleted_at" IS NOT 
NULL', - builder.to_sql() + builder.to_sql(), ) builder = self.get_builder() - builder.select('*').from_('users').join( - JoinClause('contacts') - .on('users.id', '=', 'contacts.id') - .or_where_not_null('contacts.deleted_at') + builder.select("*").from_("users").join( + JoinClause("contacts") + .on("users.id", "=", "contacts.id") + .or_where_not_null("contacts.deleted_at") ) self.assertEqual( 'SELECT * FROM "users" ' 'INNER JOIN "contacts" ' 'ON "users"."id" = "contacts"."id" ' 'OR "contacts"."deleted_at" IS NOT NULL', - builder.to_sql() + builder.to_sql(), ) def test_raw_expression_in_select(self): builder = self.get_builder() - builder.select(QueryExpression('substr(foo, 6)')).from_('users') + builder.select(QueryExpression("substr(foo, 6)")).from_("users") self.assertEqual('SELECT substr(foo, 6) FROM "users"', builder.to_sql()) def test_find_return_first_result_by_id(self): builder = self.get_builder() query = 'SELECT * FROM "users" WHERE "id" = ? LIMIT 1' - results = [{ - 'foo': 'bar' - }] + results = [{"foo": "bar"}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results) - result = builder.from_('users').find(1) - builder.get_connection().select.assert_called_once_with( - query, [1], True + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results ) + result = builder.from_("users").find(1) + builder.get_connection().select.assert_called_once_with(query, [1], True) builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertEqual(result, results[0]) def test_first_return_first_result(self): builder = self.get_builder() query = 'SELECT * FROM "users" WHERE "id" = ? 
LIMIT 1' - results = [{ - 'foo': 'bar' - }] + results = [{"foo": "bar"}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results) - result = builder.from_('users').where('id', '=', 1).first() - builder.get_connection().select.assert_called_once_with( - query, [1], True + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results ) + result = builder.from_("users").where("id", "=", 1).first() + builder.get_connection().select.assert_called_once_with(query, [1], True) builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertEqual(result, results[0]) def test_list_methods_gets_list_of_colmun_values(self): builder = self.get_builder() - results = [ - {'foo': 'bar'}, {'foo': 'baz'} - ] + results = [{"foo": "bar"}, {"foo": "baz"}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results) - result = builder.from_('users').where('id', '=', 1).lists('foo') - self.assertEqual(['bar', 'baz'], result) + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results + ) + result = builder.from_("users").where("id", "=", 1).lists("foo") + self.assertEqual(["bar", "baz"], result) builder = self.get_builder() - results = [ - {'id': 1, 'foo': 'bar'}, {'id': 10, 'foo': 'baz'} - ] + results = [{"id": 1, "foo": "bar"}, {"id": 10, "foo": "baz"}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results) - result = builder.from_('users').where('id', '=', 1).lists('foo', 'id') - self.assertEqual({1: 'bar', 10: 'baz'}, result) + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results + ) + result = 
builder.from_("users").where("id", "=", 1).lists("foo", "id") + self.assertEqual({1: "bar", 10: "baz"}, result) def test_implode(self): builder = self.get_builder() - results = [ - {'foo': 'bar'}, {'foo': 'baz'} - ] + results = [{"foo": "bar"}, {"foo": "baz"}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results) - result = builder.from_('users').where('id', '=', 1).implode('foo') - self.assertEqual('barbaz', result) + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results + ) + result = builder.from_("users").where("id", "=", 1).implode("foo") + self.assertEqual("barbaz", result) builder = self.get_builder() - results = [ - {'foo': 'bar'}, {'foo': 'baz'} - ] + results = [{"foo": "bar"}, {"foo": "baz"}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results) - result = builder.from_('users').where('id', '=', 1).implode('foo', ',') - self.assertEqual('bar,baz', result) + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results + ) + result = builder.from_("users").where("id", "=", 1).implode("foo", ",") + self.assertEqual("bar,baz", result) def test_pluck_return_single_column(self): builder = self.get_builder() query = 'SELECT "foo" FROM "users" WHERE "id" = ? 
LIMIT 1' - results = [{'foo': 'bar'}] + results = [{"foo": "bar"}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results) - result = builder.from_('users').where('id', '=', 1).pluck('foo') - builder.get_connection().select.assert_called_once_with( - query, [1], True + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results ) + result = builder.from_("users").where("id", "=", 1).pluck("foo") + builder.get_connection().select.assert_called_once_with(query, [1], True) builder.get_processor().process_select.assert_called_once_with(builder, results) - self.assertEqual('bar', result) + self.assertEqual("bar", result) def test_aggegate_functions(self): builder = self.get_builder() query = 'SELECT COUNT(*) AS aggregate FROM "users"' - results = [{'aggregate': 1}] + results = [{"aggregate": 1}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results) - result = builder.from_('users').count() - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results ) + result = builder.from_("users").count() + builder.get_connection().select.assert_called_once_with(query, [], True) builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertEqual(1, result) builder = self.get_builder() query = 'SELECT COUNT(*) AS aggregate FROM "users" LIMIT 1' - results = [{'aggregate': 1}] + results = [{"aggregate": 1}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) - result = builder.from_('users').exists() - builder.get_connection().select.assert_called_once_with( - query, [], 
True + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ ) + result = builder.from_("users").exists() + builder.get_connection().select.assert_called_once_with(query, [], True) builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertTrue(result) builder = self.get_builder() query = 'SELECT MAX("id") AS aggregate FROM "users"' - results = [{'aggregate': 1}] + results = [{"aggregate": 1}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) - result = builder.from_('users').max('id') - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ ) + result = builder.from_("users").max("id") + builder.get_connection().select.assert_called_once_with(query, [], True) builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertEqual(1, result) builder = self.get_builder() query = 'SELECT MIN("id") AS aggregate FROM "users"' - results = [{'aggregate': 1}] + results = [{"aggregate": 1}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) - result = builder.from_('users').min('id') - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ ) + result = builder.from_("users").min("id") + builder.get_connection().select.assert_called_once_with(query, [], True) builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertEqual(1, result) builder = self.get_builder() query = 'SELECT SUM("id") AS aggregate FROM "users"' - results = [{'aggregate': 1}] + results = 
[{"aggregate": 1}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) - result = builder.from_('users').sum('id') - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ ) + result = builder.from_("users").sum("id") + builder.get_connection().select.assert_called_once_with(query, [], True) builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertEqual(1, result) builder = self.get_builder() query = 'SELECT AVG("id") AS aggregate FROM "users"' - results = [{'aggregate': 1}] + results = [{"aggregate": 1}] builder.get_connection().select.return_value = results - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) - result = builder.from_('users').avg('id') - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ + ) + result = builder.from_("users").avg("id") + builder.get_connection().select.assert_called_once_with(query, [], True) + builder.get_processor().process_select.assert_called_once_with(builder, results) + self.assertEqual(1, result) + + def test_distinct_count_with_column(self): + builder = self.get_builder() + query = 'SELECT COUNT(DISTINCT "id") AS aggregate FROM "users"' + results = [{"aggregate": 1}] + builder.get_connection().select.return_value = results + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results ) + result = builder.from_("users").distinct().count("id") + builder.get_connection().select.assert_called_once_with(query, [], True) + builder.get_processor().process_select.assert_called_once_with(builder, results) + self.assertEqual(1, result) 
+ + def test_distinct_count_with_select(self): + builder = self.get_builder() + query = 'SELECT COUNT(DISTINCT "id") AS aggregate FROM "users"' + results = [{"aggregate": 1}] + builder.get_connection().select.return_value = results + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results + ) + result = builder.from_("users").distinct().select("id").count() + builder.get_connection().select.assert_called_once_with(query, [], True) builder.get_processor().process_select.assert_called_once_with(builder, results) self.assertEqual(1, result) def test_aggregate_reset_followed_by_get(self): builder = self.get_builder() query = 'SELECT COUNT(*) AS aggregate FROM "users"' - builder.get_connection().select.return_value = [{'aggregate': 1}] - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) - builder.from_('users').select('column1', 'column2') + builder.get_connection().select.return_value = [{"aggregate": 1}] + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ + ) + builder.from_("users").select("column1", "column2") count = builder.count() - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_connection().select.assert_called_once_with(query, [], True) + builder.get_processor().process_select.assert_called_once_with( + builder, [{"aggregate": 1}] ) - builder.get_processor().process_select.assert_called_once_with(builder, [{'aggregate': 1}]) self.assertEqual(1, count) builder.get_connection().select.reset_mock() builder.get_processor().process_select.reset_mock() query = 'SELECT SUM("id") AS aggregate FROM "users"' - builder.get_connection().select.return_value = [{'aggregate': 2}] - sum_ = builder.sum('id') - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_connection().select.return_value = [{"aggregate": 2}] + sum_ = builder.sum("id") 
+ builder.get_connection().select.assert_called_once_with(query, [], True) + builder.get_processor().process_select.assert_called_once_with( + builder, [{"aggregate": 2}] ) - builder.get_processor().process_select.assert_called_once_with(builder, [{'aggregate': 2}]) self.assertEqual(2, sum_) builder.get_connection().select.reset_mock() builder.get_processor().process_select.reset_mock() query = 'SELECT "column1", "column2" FROM "users"' - builder.get_connection().select.return_value = [{'column1': 'foo', 'column2': 'bar'}] + builder.get_connection().select.return_value = [ + {"column1": "foo", "column2": "bar"} + ] result = builder.get() - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_connection().select.assert_called_once_with(query, [], True) + builder.get_processor().process_select.assert_called_once_with( + builder, [{"column1": "foo", "column2": "bar"}] ) - builder.get_processor().process_select.assert_called_once_with(builder, [{'column1': 'foo', 'column2': 'bar'}]) - self.assertEqual([{'column1': 'foo', 'column2': 'bar'}], result) + self.assertEqual([{"column1": "foo", "column2": "bar"}], result) def test_aggregate_reset_followed_by_select_get(self): builder = self.get_builder() query = 'SELECT COUNT("column1") AS aggregate FROM "users"' - builder.get_connection().select.return_value = [{'aggregate': 1}] - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) - builder.from_('users') - count = builder.count('column1') - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_connection().select.return_value = [{"aggregate": 1}] + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ + ) + builder.from_("users") + count = builder.count("column1") + builder.get_connection().select.assert_called_once_with(query, [], True) + 
builder.get_processor().process_select.assert_called_once_with( + builder, [{"aggregate": 1}] ) - builder.get_processor().process_select.assert_called_once_with(builder, [{'aggregate': 1}]) self.assertEqual(1, count) builder.get_connection().select.reset_mock() builder.get_processor().process_select.reset_mock() query = 'SELECT "column2", "column3" FROM "users"' - builder.get_connection().select.return_value = [{'column2': 'foo', 'column3': 'bar'}] - result = builder.select('column2', 'column3').get() - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_connection().select.return_value = [ + {"column2": "foo", "column3": "bar"} + ] + result = builder.select("column2", "column3").get() + builder.get_connection().select.assert_called_once_with(query, [], True) + builder.get_processor().process_select.assert_called_once_with( + builder, [{"column2": "foo", "column3": "bar"}] ) - builder.get_processor().process_select.assert_called_once_with(builder, [{'column2': 'foo', 'column3': 'bar'}]) - self.assertEqual([{'column2': 'foo', 'column3': 'bar'}], result) + self.assertEqual([{"column2": "foo", "column3": "bar"}], result) def test_aggregate_reset_followed_by_get_with_columns(self): builder = self.get_builder() query = 'SELECT COUNT("column1") AS aggregate FROM "users"' - builder.get_connection().select.return_value = [{'aggregate': 1}] - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) - builder.from_('users') - count = builder.count('column1') - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_connection().select.return_value = [{"aggregate": 1}] + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ + ) + builder.from_("users") + count = builder.count("column1") + builder.get_connection().select.assert_called_once_with(query, [], True) + 
builder.get_processor().process_select.assert_called_once_with( + builder, [{"aggregate": 1}] ) - builder.get_processor().process_select.assert_called_once_with(builder, [{'aggregate': 1}]) self.assertEqual(1, count) builder.get_connection().select.reset_mock() builder.get_processor().process_select.reset_mock() query = 'SELECT "column2", "column3" FROM "users"' - builder.get_connection().select.return_value = [{'column2': 'foo', 'column3': 'bar'}] - result = builder.get(['column2', 'column3']) - builder.get_connection().select.assert_called_once_with( - query, [], True + builder.get_connection().select.return_value = [ + {"column2": "foo", "column3": "bar"} + ] + result = builder.get(["column2", "column3"]) + builder.get_connection().select.assert_called_once_with(query, [], True) + builder.get_processor().process_select.assert_called_once_with( + builder, [{"column2": "foo", "column3": "bar"}] ) - builder.get_processor().process_select.assert_called_once_with(builder, [{'column2': 'foo', 'column3': 'bar'}]) - self.assertEqual([{'column2': 'foo', 'column3': 'bar'}], result) + self.assertEqual([{"column2": "foo", "column3": "bar"}], result) def test_insert_method(self): builder = self.get_builder() query = 'INSERT INTO "users" ("email") VALUES (?)' builder.get_connection().insert.return_value = True - result = builder.from_('users').insert({'email': 'foo'}) - builder.get_connection().insert.assert_called_once_with( - query, ['foo'] - ) + result = builder.from_("users").insert({"email": "foo"}) + builder.get_connection().insert.assert_called_once_with(query, ["foo"]) self.assertTrue(result) def test_insert_method_with_keyword_arguments(self): builder = self.get_builder() query = 'INSERT INTO "users" ("email") VALUES (?)' builder.get_connection().insert.return_value = True - result = builder.from_('users').insert({'email': 'foo'}) - builder.get_connection().insert.assert_called_once_with( - query, ['foo'] - ) + result = builder.from_("users").insert({"email": "foo"}) 
+ builder.get_connection().insert.assert_called_once_with(query, ["foo"]) self.assertTrue(result) def test_sqlite_multiple_insert(self): builder = self.get_sqlite_builder() - query = 'INSERT INTO "users" ("email", "name") ' \ - 'SELECT ? AS "email", ? AS "name" UNION ALL SELECT ? AS "email", ? AS "name"' + query = ( + 'INSERT INTO "users" ("email", "name") ' + 'SELECT ? AS "email", ? AS "name" UNION ALL SELECT ? AS "email", ? AS "name"' + ) builder.get_connection().insert.return_value = True - result = builder.from_('users').insert([ - {'email': 'foo', 'name': 'john'}, - {'email': 'bar', 'name': 'jane'} - ]) + result = builder.from_("users").insert( + [{"email": "foo", "name": "john"}, {"email": "bar", "name": "jane"}] + ) builder.get_connection().insert.assert_called_once_with( - query, ['foo', 'john', 'bar', 'jane'] + query, ["foo", "john", "bar", "jane"] ) self.assertTrue(result) def test_insert_get_id_method(self): builder = self.get_builder() builder.get_processor().process_insert_get_id.return_value = 1 - result = builder.from_('users').insert_get_id({ - 'email': 'foo', - 'bar': QueryExpression('bar') - }) + result = builder.from_("users").insert_get_id( + {"email": "foo", "bar": QueryExpression("bar")} + ) builder.get_processor().process_insert_get_id.assert_called_once_with( - builder, 'INSERT INTO "users" ("bar", "email") VALUES (bar, ?)', ['foo'], None + builder, + 'INSERT INTO "users" ("bar", "email") VALUES (bar, ?)', + ["foo"], + None, ) self.assertEqual(1, result) def test_insert_get_id_with_sequence(self): builder = self.get_builder() builder.get_processor().process_insert_get_id.return_value = 1 - result = builder.from_('users').insert_get_id({ - 'email': 'foo', - 'bar': QueryExpression('bar') - }, 'id') + result = builder.from_("users").insert_get_id( + {"email": "foo", "bar": QueryExpression("bar")}, "id" + ) builder.get_processor().process_insert_get_id.assert_called_once_with( - builder, 'INSERT INTO "users" ("bar", "email") VALUES (bar, ?)', 
['foo'], 'id' + builder, + 'INSERT INTO "users" ("bar", "email") VALUES (bar, ?)', + ["foo"], + "id", ) self.assertEqual(1, result) def test_insert_get_id_respects_raw_bindings(self): builder = self.get_builder() builder.get_processor().process_insert_get_id.return_value = 1 - result = builder.from_('users').insert_get_id({ - 'email': QueryExpression('CURRENT_TIMESTAMP'), - }) + result = builder.from_("users").insert_get_id( + {"email": QueryExpression("CURRENT_TIMESTAMP")} + ) builder.get_processor().process_insert_get_id.assert_called_once_with( - builder, 'INSERT INTO "users" ("email") VALUES (CURRENT_TIMESTAMP)', [], None + builder, + 'INSERT INTO "users" ("email") VALUES (CURRENT_TIMESTAMP)', + [], + None, ) self.assertEqual(1, result) @@ -1230,225 +1258,258 @@ def test_update(self): builder = self.get_builder() query = 'UPDATE "users" SET "email" = ?, "name" = ? WHERE "id" = ?' builder.get_connection().update.return_value = 1 - result = builder.from_('users').where('id', '=', 1).update(email='foo', name='bar') - builder.get_connection().update.assert_called_with( - query, ['foo', 'bar', 1] + result = ( + builder.from_("users").where("id", "=", 1).update(email="foo", name="bar") ) + builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1]) self.assertEqual(1, result) builder = self.get_mysql_builder() marker = builder.get_grammar().get_marker() - query = 'UPDATE `users` SET `email` = %s, `name` = %s WHERE `id` = %s' % (marker, marker, marker) + query = "UPDATE `users` SET `email` = %s, `name` = %s WHERE `id` = %s" % ( + marker, + marker, + marker, + ) builder.get_connection().update.return_value = 1 - result = builder.from_('users').where('id', '=', 1).update(email='foo', name='bar') - builder.get_connection().update.assert_called_with( - query, ['foo', 'bar', 1] + result = ( + builder.from_("users").where("id", "=", 1).update(email="foo", name="bar") ) + builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1]) 
self.assertEqual(1, result) def test_update_with_dictionaries(self): builder = self.get_builder() query = 'UPDATE "users" SET "email" = ?, "name" = ? WHERE "id" = ?' builder.get_connection().update.return_value = 1 - result = builder.from_('users').where('id', '=', 1).update({'email': 'foo', 'name': 'bar'}) - builder.get_connection().update.assert_called_with( - query, ['foo', 'bar', 1] + result = ( + builder.from_("users") + .where("id", "=", 1) + .update({"email": "foo", "name": "bar"}) ) + builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1]) self.assertEqual(1, result) builder = self.get_builder() query = 'UPDATE "users" SET "email" = ?, "name" = ? WHERE "id" = ?' builder.get_connection().update.return_value = 1 - result = builder.from_('users').where('id', '=', 1).update({'email': 'foo'}, name='bar') - builder.get_connection().update.assert_called_with( - query, ['foo', 'bar', 1] + result = ( + builder.from_("users") + .where("id", "=", 1) + .update({"email": "foo"}, name="bar") ) + builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1]) self.assertEqual(1, result) + def test_where_date(self): + builder = self.get_sqlite_builder() + builder.where_date("date", "=", "10-20-2018") + + self.assertEqual( + builder.to_sql(), + 'SELECT * FROM "" WHERE strftime(\'%Y-%m-%d\', "date") = ?', + ) + def test_update_with_joins(self): builder = self.get_builder() - query = 'UPDATE "users" ' \ - 'INNER JOIN "orders" ON "users"."id" = "orders"."user_id" ' \ - 'SET "email" = ?, "name" = ? WHERE "id" = ?' + query = ( + 'UPDATE "users" ' + 'INNER JOIN "orders" ON "users"."id" = "orders"."user_id" ' + 'SET "email" = ?, "name" = ? WHERE "id" = ?' 
+ ) builder.get_connection().update.return_value = 1 - result = builder.from_('users')\ - .join('orders', 'users.id', '=', 'orders.user_id')\ - .where('id', '=', 1)\ - .update(email='foo', name='bar') - builder.get_connection().update.assert_called_with( - query, ['foo', 'bar', 1] + result = ( + builder.from_("users") + .join("orders", "users.id", "=", "orders.user_id") + .where("id", "=", 1) + .update(email="foo", name="bar") ) + builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1]) self.assertEqual(1, result) def test_update_on_postgres(self): builder = self.get_postgres_builder() marker = builder.get_grammar().get_marker() - query = 'UPDATE "users" SET "email" = %s, "name" = %s WHERE "id" = %s' % (marker, marker, marker) + query = 'UPDATE "users" SET "email" = %s, "name" = %s WHERE "id" = %s' % ( + marker, + marker, + marker, + ) builder.get_connection().update.return_value = 1 - result = builder.from_('users').where('id', '=', 1).update(email='foo', name='bar') - builder.get_connection().update.assert_called_with( - query, ['foo', 'bar', 1] + result = ( + builder.from_("users").where("id", "=", 1).update(email="foo", name="bar") ) + builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1]) self.assertEqual(1, result) def test_update_with_joins_on_postgres(self): builder = self.get_postgres_builder() marker = builder.get_grammar().get_marker() - query = 'UPDATE "users" ' \ - 'SET "email" = %s, "name" = %s ' \ - 'FROM "orders" WHERE "id" = %s AND "users"."id" = "orders"."user_id"'\ - % (marker, marker, marker) + query = ( + 'UPDATE "users" ' + 'SET "email" = %s, "name" = %s ' + 'FROM "orders" WHERE "id" = %s AND "users"."id" = "orders"."user_id"' + % (marker, marker, marker) + ) builder.get_connection().update.return_value = 1 - result = builder.from_('users')\ - .join('orders', 'users.id', '=', 'orders.user_id')\ - .where('id', '=', 1)\ - .update(email='foo', name='bar') + result = ( + builder.from_("users") + 
.join("orders", "users.id", "=", "orders.user_id") + .where("id", "=", 1) + .update(email="foo", name="bar") + ) builder.get_connection().update.assert_called_once_with( - query, ['foo', 'bar', 1] + query, ["foo", "bar", 1] ) self.assertEqual(1, result) def test_update_respects_raw(self): builder = self.get_builder() - marker = '?' - query = 'UPDATE "users" SET "email" = foo, "name" = %s WHERE "id" = %s' % (marker, marker) + marker = "?" + query = 'UPDATE "users" SET "email" = foo, "name" = %s WHERE "id" = %s' % ( + marker, + marker, + ) builder.get_connection().update.return_value = 1 - result = builder.from_('users').where('id', '=', 1).update(email=QueryExpression('foo'), name='bar') - builder.get_connection().update.assert_called_once_with( - query, ['bar', 1] + result = ( + builder.from_("users") + .where("id", "=", 1) + .update(email=QueryExpression("foo"), name="bar") ) + builder.get_connection().update.assert_called_once_with(query, ["bar", 1]) self.assertEqual(1, result) def test_delete(self): builder = self.get_builder() query = 'DELETE FROM "users" WHERE "email" = ?' builder.get_connection().delete.return_value = 1 - result = builder.from_('users').where('email', '=', 'foo').delete() - builder.get_connection().delete.assert_called_once_with( - query, ['foo'] - ) + result = builder.from_("users").where("email", "=", "foo").delete() + builder.get_connection().delete.assert_called_once_with(query, ["foo"]) self.assertEqual(1, result) builder = self.get_builder() query = 'DELETE FROM "users" WHERE "id" = ?' 
builder.get_connection().delete.return_value = 1 - result = builder.from_('users').delete(1) - builder.get_connection().delete.assert_called_once_with( - query, [1] - ) + result = builder.from_("users").delete(1) + builder.get_connection().delete.assert_called_once_with(query, [1]) self.assertEqual(1, result) def test_delete_with_join(self): builder = self.get_mysql_builder() marker = builder.get_grammar().get_marker() - query = 'DELETE `users` FROM `users` ' \ - 'INNER JOIN `contacts` ON `users`.`id` = `contacts`.`id` WHERE `email` = %s' % marker + query = ( + "DELETE `users` FROM `users` " + "INNER JOIN `contacts` ON `users`.`id` = `contacts`.`id` WHERE `email` = %s" + % marker + ) builder.get_connection().delete.return_value = 1 - result = builder.from_('users')\ - .join('contacts', 'users.id', '=', 'contacts.id')\ - .where('email', '=', 'foo')\ + result = ( + builder.from_("users") + .join("contacts", "users.id", "=", "contacts.id") + .where("email", "=", "foo") .delete() - builder.get_connection().delete.assert_called_once_with( - query, ['foo'] ) + builder.get_connection().delete.assert_called_once_with(query, ["foo"]) self.assertEqual(1, result) builder = self.get_mysql_builder() marker = builder.get_grammar().get_marker() - query = 'DELETE `users` FROM `users` ' \ - 'INNER JOIN `contacts` ON `users`.`id` = `contacts`.`id` WHERE `id` = %s' % marker + query = ( + "DELETE `users` FROM `users` " + "INNER JOIN `contacts` ON `users`.`id` = `contacts`.`id` WHERE `id` = %s" + % marker + ) builder.get_connection().delete.return_value = 1 - result = builder.from_('users')\ - .join('contacts', 'users.id', '=', 'contacts.id')\ + result = ( + builder.from_("users") + .join("contacts", "users.id", "=", "contacts.id") .delete(1) - builder.get_connection().delete.assert_called_once_with( - query, [1] ) + builder.get_connection().delete.assert_called_once_with(query, [1]) self.assertEqual(1, result) def test_truncate(self): builder = self.get_builder() query = 'TRUNCATE 
"users"' - builder.from_('users').truncate() - builder.get_connection().statement.assert_called_once_with( - query, [] - ) + builder.from_("users").truncate() + builder.get_connection().statement.assert_called_once_with(query, []) builder = self.get_sqlite_builder() - builder.from_('users') - self.assertEqual({ - 'DELETE FROM sqlite_sequence WHERE name = ?': ['users'], - 'DELETE FROM "users"': [] - }, builder.get_grammar().compile_truncate(builder)) + builder.from_("users") + self.assertEqual( + { + "DELETE FROM sqlite_sequence WHERE name = ?": ["users"], + 'DELETE FROM "users"': [], + }, + builder.get_grammar().compile_truncate(builder), + ) def test_postgres_insert_get_id(self): builder = self.get_postgres_builder() marker = builder.get_grammar().get_marker() query = 'INSERT INTO "users" ("email") VALUES (%s) RETURNING "id"' % marker builder.get_processor().process_insert_get_id.return_value = 1 - result = builder.from_('users').insert_get_id({'email': 'foo'}, 'id') + result = builder.from_("users").insert_get_id({"email": "foo"}, "id") builder.get_processor().process_insert_get_id.assert_called_once_with( - builder, query, ['foo'], 'id' + builder, query, ["foo"], "id" ) self.assertEqual(1, result) def test_mysql_wrapping(self): builder = self.get_mysql_builder() - builder.select('*').from_('users') - self.assertEqual( - 'SELECT * FROM `users`', - builder.to_sql() - ) + builder.select("*").from_("users") + self.assertEqual("SELECT * FROM `users`", builder.to_sql()) def test_merge_wheres_can_merge_wheres_and_bindings(self): builder = self.get_builder() - builder.wheres = ['foo'] - builder.merge_wheres(['wheres'], ['foo', 'bar']) - self.assertEqual(['foo', 'wheres'], builder.wheres) - self.assertEqual(['foo', 'bar'], builder.get_bindings()) + builder.wheres = ["foo"] + builder.merge_wheres(["wheres"], ["foo", "bar"]) + self.assertEqual(["foo", "wheres"], builder.wheres) + self.assertEqual(["foo", "bar"], builder.get_bindings()) def 
test_where_with_null_second_parameter(self): builder = self.get_builder() - builder.select('*').from_('users').where('foo', None) - self.assertEqual( - 'SELECT * FROM "users" WHERE "foo" IS NULL', - builder.to_sql() - ) + builder.select("*").from_("users").where("foo", None) + self.assertEqual('SELECT * FROM "users" WHERE "foo" IS NULL', builder.to_sql()) def test_dynamic_where(self): - method = 'where_foo_bar_and_baz_or_boom' - parameters = ['john', 'jane', 'bam'] + method = "where_foo_bar_and_baz_or_boom" + parameters = ["john", "jane", "bam"] builder = self.get_builder() builder.where = mock.MagicMock(return_value=builder) getattr(builder, method)(*parameters) - builder.where.assert_has_calls([ - mock.call('foo_bar', '=', parameters[0], 'and'), - mock.call('baz', '=', parameters[1], 'and'), - mock.call('boom', '=', parameters[2], 'or') - ]) + builder.where.assert_has_calls( + [ + mock.call("foo_bar", "=", parameters[0], "and"), + mock.call("baz", "=", parameters[1], "and"), + mock.call("boom", "=", parameters[2], "or"), + ] + ) def test_dynamic_where_is_not_greedy(self): - method = 'where_ios_version_and_android_version_or_orientation' - parameters = ['6.1', '4.2', 'Vertical'] + method = "where_ios_version_and_android_version_or_orientation" + parameters = ["6.1", "4.2", "Vertical"] builder = self.get_builder() builder.where = mock.MagicMock(return_value=builder) getattr(builder, method)(*parameters) - builder.where.assert_has_calls([ - mock.call('ios_version', '=', parameters[0], 'and'), - mock.call('android_version', '=', parameters[1], 'and'), - mock.call('orientation', '=', parameters[2], 'or') - ]) + builder.where.assert_has_calls( + [ + mock.call("ios_version", "=", parameters[0], "and"), + mock.call("android_version", "=", parameters[1], "and"), + mock.call("orientation", "=", parameters[2], "or"), + ] + ) def test_call_triggers_dynamic_where(self): builder = self.get_builder() - self.assertEqual(builder, builder.where_foo_and_bar('baz', 'boom')) + 
self.assertEqual(builder, builder.where_foo_and_bar("baz", "boom")) self.assertEqual(2, len(builder.wheres)) def test_builder_raises_exception_with_undefined_method(self): @@ -1456,171 +1517,164 @@ def test_builder_raises_exception_with_undefined_method(self): try: builder.do_not_exist() - self.fail('Builder did not raise and AttributeError exception') + self.fail("Builder did not raise and AttributeError exception") except AttributeError: self.assertTrue(True) def test_mysql_lock(self): builder = self.get_mysql_builder() marker = builder.get_grammar().get_marker() - builder.select('*').from_('foo').where('bar', '=', 'baz').lock() + builder.select("*").from_("foo").where("bar", "=", "baz").lock() self.assertEqual( - 'SELECT * FROM `foo` WHERE `bar` = %s FOR UPDATE' % marker, - builder.to_sql() + "SELECT * FROM `foo` WHERE `bar` = %s FOR UPDATE" % marker, builder.to_sql() ) - self.assertEqual(['baz'], builder.get_bindings()) + self.assertEqual(["baz"], builder.get_bindings()) builder = self.get_mysql_builder() marker = builder.get_grammar().get_marker() - builder.select('*').from_('foo').where('bar', '=', 'baz').lock(False) + builder.select("*").from_("foo").where("bar", "=", "baz").lock(False) self.assertEqual( - 'SELECT * FROM `foo` WHERE `bar` = %s LOCK IN SHARE MODE' % marker, - builder.to_sql() + "SELECT * FROM `foo` WHERE `bar` = %s LOCK IN SHARE MODE" % marker, + builder.to_sql(), ) - self.assertEqual(['baz'], builder.get_bindings()) + self.assertEqual(["baz"], builder.get_bindings()) def test_postgres_lock(self): builder = self.get_postgres_builder() marker = builder.get_grammar().get_marker() - builder.select('*').from_('foo').where('bar', '=', 'baz').lock() + builder.select("*").from_("foo").where("bar", "=", "baz").lock() self.assertEqual( - 'SELECT * FROM "foo" WHERE "bar" = %s FOR UPDATE' % marker, - builder.to_sql() + 'SELECT * FROM "foo" WHERE "bar" = %s FOR UPDATE' % marker, builder.to_sql() ) - self.assertEqual(['baz'], builder.get_bindings()) + 
self.assertEqual(["baz"], builder.get_bindings()) builder = self.get_postgres_builder() marker = builder.get_grammar().get_marker() - builder.select('*').from_('foo').where('bar', '=', 'baz').lock(False) + builder.select("*").from_("foo").where("bar", "=", "baz").lock(False) self.assertEqual( - 'SELECT * FROM "foo" WHERE "bar" = %s FOR SHARE' % marker, - builder.to_sql() + 'SELECT * FROM "foo" WHERE "bar" = %s FOR SHARE' % marker, builder.to_sql() ) - self.assertEqual(['baz'], builder.get_bindings()) + self.assertEqual(["baz"], builder.get_bindings()) def test_binding_order(self): - expected_sql = 'SELECT * FROM "users" ' \ - 'INNER JOIN "othertable" ON "bar" = ? ' \ - 'WHERE "registered" = ? ' \ - 'GROUP BY "city" ' \ - 'HAVING "population" > ? ' \ - 'ORDER BY match ("foo") against(?)' - expected_bindings = ['foo', True, 3, 'bar'] - - builder = self.get_builder() - builder.select('*').from_('users')\ - .order_by_raw('match ("foo") against(?)', ['bar'])\ - .where('registered', True)\ - .group_by('city')\ - .having('population', '>', 3)\ - .join(JoinClause('othertable').where('bar', '=', 'foo')) + expected_sql = ( + 'SELECT * FROM "users" ' + 'INNER JOIN "othertable" ON "bar" = ? ' + 'WHERE "registered" = ? ' + 'GROUP BY "city" ' + 'HAVING "population" > ? 
' + 'ORDER BY match ("foo") against(?)' + ) + expected_bindings = ["foo", True, 3, "bar"] + + builder = self.get_builder() + builder.select("*").from_("users").order_by_raw( + 'match ("foo") against(?)', ["bar"] + ).where("registered", True).group_by("city").having("population", ">", 3).join( + JoinClause("othertable").where("bar", "=", "foo") + ) self.assertEqual(expected_sql, builder.to_sql()) self.assertEqual(expected_bindings, builder.get_bindings()) def test_add_binding_with_list_merges_bindings(self): builder = self.get_builder() - builder.add_binding(['foo', 'bar']) - builder.add_binding(['baz']) - self.assertEqual(['foo', 'bar', 'baz'], builder.get_bindings()) + builder.add_binding(["foo", "bar"]) + builder.add_binding(["baz"]) + self.assertEqual(["foo", "bar", "baz"], builder.get_bindings()) def test_add_binding_with_list_merges_bindings_in_correct_order(self): builder = self.get_builder() - builder.add_binding(['bar', 'baz'], 'having') - builder.add_binding(['foo'], 'where') - self.assertEqual(['foo', 'bar', 'baz'], builder.get_bindings()) + builder.add_binding(["bar", "baz"], "having") + builder.add_binding(["foo"], "where") + self.assertEqual(["foo", "bar", "baz"], builder.get_bindings()) def test_merge_builders(self): builder = self.get_builder() - builder.add_binding('foo', 'where') - builder.add_binding('baz', 'having') + builder.add_binding("foo", "where") + builder.add_binding("baz", "having") other_builder = self.get_builder() - other_builder.add_binding('bar', 'where') + other_builder.add_binding("bar", "where") builder.merge_bindings(other_builder) - self.assertEqual(['foo', 'bar', 'baz'], builder.get_bindings()) + self.assertEqual(["foo", "bar", "baz"], builder.get_bindings()) def test_sub_select(self): builder = self.get_builder() marker = builder.get_grammar().get_marker() - expected_sql = 'SELECT "foo", "bar", (SELECT "baz" FROM "two" WHERE "subkey" = %s) AS "sub" ' \ - 'FROM "one" WHERE "key" = %s' % (marker, marker) - expected_bindings = 
['subval', 'val'] + expected_sql = ( + 'SELECT "foo", "bar", (SELECT "baz" FROM "two" WHERE "subkey" = %s) AS "sub" ' + 'FROM "one" WHERE "key" = %s' % (marker, marker) + ) + expected_bindings = ["subval", "val"] - builder.from_('one').select('foo', 'bar').where('key', '=', 'val') - builder.select_sub(builder.new_query().from_('two').select('baz').where('subkey', '=', 'subval'), 'sub') + builder.from_("one").select("foo", "bar").where("key", "=", "val") + builder.select_sub( + builder.new_query() + .from_("two") + .select("baz") + .where("subkey", "=", "subval"), + "sub", + ) self.assertEqual(expected_sql, builder.to_sql()) self.assertEqual(expected_bindings, builder.get_bindings()) def test_chunk(self): builder = self.get_builder() - results = [ - {'foo': 'bar'}, - {'foo': 'baz'}, - {'foo': 'bam'}, - {'foo': 'boom'} - ] + results = [{"foo": "bar"}, {"foo": "baz"}, {"foo": "bam"}, {"foo": "boom"}] def select(query, bindings, _): - index = int(re.search('OFFSET (\d+)', query).group(1)) - limit = int(re.search('LIMIT (\d+)', query).group(1)) + index = int(re.search("OFFSET (\d+)", query).group(1)) + limit = int(re.search("LIMIT (\d+)", query).group(1)) if index >= len(results): return [] - return results[index:index + limit] + return results[index : index + limit] builder.get_connection().select.side_effect = select - builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_) + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ + ) i = 0 - for users in builder.from_('users').chunk(1): + for users in builder.from_("users").chunk(1): self.assertEqual(users[0], results[i]) i += 1 builder = self.get_builder() - results = [ - {'foo': 'bar'}, - {'foo': 'baz'}, - {'foo': 'bam'}, - {'foo': 'boom'} - ] + results = [{"foo": "bar"}, {"foo": "baz"}, {"foo": "bam"}, {"foo": "boom"}] builder.get_connection().select.side_effect = select - builder.get_processor().process_select = 
mock.MagicMock(side_effect=lambda builder_, results_: results_) + builder.get_processor().process_select = mock.MagicMock( + side_effect=lambda builder_, results_: results_ + ) - for users in builder.from_('users').chunk(2): + for users in builder.from_("users").chunk(2): self.assertEqual(2, len(users)) def test_not_specifying_columns_sects_all(self): builder = self.get_builder() - builder.from_('users') + builder.from_("users") - self.assertEqual( - 'SELECT * FROM "users"', - builder.to_sql() - ) + self.assertEqual('SELECT * FROM "users"', builder.to_sql()) def test_merge(self): b1 = self.get_builder() - b1.from_('test').select('foo', 'bar').where('baz', 'boom') + b1.from_("test").select("foo", "bar").where("baz", "boom") b2 = self.get_builder() - b2.where('foo', 'bar') + b2.where("foo", "bar") b1.merge(b2) self.assertEqual( - 'SELECT "foo", "bar" FROM "test" WHERE "baz" = ? AND "foo" = ?', - b1.to_sql() + 'SELECT "foo", "bar" FROM "test" WHERE "baz" = ? AND "foo" = ?', b1.to_sql() ) - self.assertEqual( - ['boom', 'bar'], - b1.get_bindings() - ) + self.assertEqual(["boom", "bar"], b1.get_bindings()) def get_mysql_builder(self): grammar = MySQLQueryGrammar() diff --git a/tests/schema/grammars/__init__.py b/tests/schema/grammars/__init__.py index 633f8661..40a96afc 100644 --- a/tests/schema/grammars/__init__.py +++ b/tests/schema/grammars/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/schema/grammars/test_mysql_grammar.py b/tests/schema/grammars/test_mysql_grammar.py index 80d50c19..b233d0e3 100644 --- a/tests/schema/grammars/test_mysql_grammar.py +++ b/tests/schema/grammars/test_mysql_grammar.py @@ -9,56 +9,57 @@ class MySQLSchemaGrammarTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_basic_create(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.create() - blueprint.increments('id') - blueprint.string('email') + blueprint.increments("id") + blueprint.string("email") conn = 
self.get_connection() - conn.should_receive('get_config').once().with_args('charset').and_return('utf8') - conn.should_receive('get_config').once().with_args('collation').and_return('utf8_unicode_ci') + conn.should_receive("get_config").once().with_args("charset").and_return("utf8") + conn.should_receive("get_config").once().with_args("collation").and_return( + "utf8_unicode_ci" + ) statements = blueprint.to_sql(conn, self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'CREATE TABLE `users` (' - '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, ' - '`email` VARCHAR(255) NOT NULL) ' - 'DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci', - statements[0] + "CREATE TABLE `users` (" + "`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, " + "`email` VARCHAR(255) NOT NULL) " + "DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci", + statements[0], ) - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.create() - blueprint.increments('id') - blueprint.string('email') + blueprint.increments("id") + blueprint.string("email") conn = self.get_connection() - conn.should_receive('get_config').and_return(None) + conn.should_receive("get_config").and_return(None) statements = blueprint.to_sql(conn, self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'CREATE TABLE `users` (' - '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, ' - '`email` VARCHAR(255) NOT NULL)', - statements[0] + "CREATE TABLE `users` (" + "`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, " + "`email` VARCHAR(255) NOT NULL)", + statements[0], ) def test_charset_collation_create(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.create() - blueprint.increments('id') - blueprint.string('email') - blueprint.charset = 'utf8mb4' - blueprint.collation = 'utf8mb4_unicode_ci' + blueprint.increments("id") + blueprint.string("email") + blueprint.charset = "utf8mb4" + blueprint.collation = 
"utf8mb4_unicode_ci" conn = self.get_connection() @@ -66,540 +67,557 @@ def test_charset_collation_create(self): self.assertEqual(1, len(statements)) self.assertEqual( - 'CREATE TABLE `users` (' - '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, ' - '`email` VARCHAR(255) NOT NULL) ' - 'DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci', - statements[0] + "CREATE TABLE `users` (" + "`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, " + "`email` VARCHAR(255) NOT NULL) " + "DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", + statements[0], ) def test_basic_create_with_prefix(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.create() - blueprint.increments('id') - blueprint.string('email') + blueprint.increments("id") + blueprint.string("email") grammar = self.get_grammar() - grammar.set_table_prefix('prefix_') + grammar.set_table_prefix("prefix_") conn = self.get_connection() - conn.should_receive('get_config').and_return(None) + conn.should_receive("get_config").and_return(None) statements = blueprint.to_sql(conn, grammar) self.assertEqual(1, len(statements)) self.assertEqual( - 'CREATE TABLE `prefix_users` (' - '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, ' - '`email` VARCHAR(255) NOT NULL)', - statements[0] + "CREATE TABLE `prefix_users` (" + "`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, " + "`email` VARCHAR(255) NOT NULL)", + statements[0], ) def test_drop_table(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.drop() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('DROP TABLE `users`', statements[0]) + self.assertEqual("DROP TABLE `users`", statements[0]) def test_drop_table_if_exists(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.drop_if_exists() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) 
self.assertEqual(1, len(statements)) - self.assertEqual('DROP TABLE IF EXISTS `users`', statements[0]) + self.assertEqual("DROP TABLE IF EXISTS `users`", statements[0]) def test_drop_column(self): - blueprint = Blueprint('users') - blueprint.drop_column('foo') + blueprint = Blueprint("users") + blueprint.drop_column("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE `users` DROP `foo`', statements[0]) + self.assertEqual("ALTER TABLE `users` DROP `foo`", statements[0]) - blueprint = Blueprint('users') - blueprint.drop_column('foo', 'bar') + blueprint = Blueprint("users") + blueprint.drop_column("foo", "bar") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE `users` DROP `foo`, DROP `bar`', statements[0]) + self.assertEqual("ALTER TABLE `users` DROP `foo`, DROP `bar`", statements[0]) def test_drop_primary(self): - blueprint = Blueprint('users') - blueprint.drop_primary('foo') + blueprint = Blueprint("users") + blueprint.drop_primary("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE `users` DROP PRIMARY KEY', statements[0]) + self.assertEqual("ALTER TABLE `users` DROP PRIMARY KEY", statements[0]) def test_drop_unique(self): - blueprint = Blueprint('users') - blueprint.drop_unique('foo') + blueprint = Blueprint("users") + blueprint.drop_unique("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE `users` DROP INDEX foo', statements[0]) + self.assertEqual("ALTER TABLE `users` DROP INDEX foo", statements[0]) def test_drop_index(self): - blueprint = Blueprint('users') - blueprint.drop_index('foo') + blueprint = Blueprint("users") + blueprint.drop_index("foo") statements = 
blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE `users` DROP INDEX foo', statements[0]) + self.assertEqual("ALTER TABLE `users` DROP INDEX foo", statements[0]) def test_drop_foreign(self): - blueprint = Blueprint('users') - blueprint.drop_foreign('foo') + blueprint = Blueprint("users") + blueprint.drop_foreign("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE `users` DROP FOREIGN KEY foo', statements[0]) + self.assertEqual("ALTER TABLE `users` DROP FOREIGN KEY foo", statements[0]) def test_drop_timestamps(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.drop_timestamps() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE `users` DROP `created_at`, DROP `updated_at`', statements[0]) + self.assertEqual( + "ALTER TABLE `users` DROP `created_at`, DROP `updated_at`", statements[0] + ) def test_rename_table(self): - blueprint = Blueprint('users') - blueprint.rename('foo') + blueprint = Blueprint("users") + blueprint.rename("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('RENAME TABLE `users` TO `foo`', statements[0]) + self.assertEqual("RENAME TABLE `users` TO `foo`", statements[0]) def test_adding_primary_key(self): - blueprint = Blueprint('users') - blueprint.primary('foo', 'bar') + blueprint = Blueprint("users") + blueprint.primary("foo", "bar") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE `users` ADD PRIMARY KEY bar(`foo`)', statements[0]) + self.assertEqual( + "ALTER TABLE `users` ADD PRIMARY KEY bar(`foo`)", statements[0] + ) def test_adding_foreign_key(self): - blueprint = 
Blueprint('users') - blueprint.foreign('order_id').references('id').on('orders') + blueprint = Blueprint("users") + blueprint.foreign("order_id").references("id").on("orders") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) expected = [ - 'ALTER TABLE `users` ADD CONSTRAINT users_order_id_foreign ' - 'FOREIGN KEY (`order_id`) REFERENCES `orders` (`id`)' + "ALTER TABLE `users` ADD CONSTRAINT users_order_id_foreign " + "FOREIGN KEY (`order_id`) REFERENCES `orders` (`id`)" ] self.assertEqual(expected, statements) def test_adding_unique_key(self): - blueprint = Blueprint('users') - blueprint.unique('foo', 'bar') + blueprint = Blueprint("users") + blueprint.unique("foo", "bar") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD UNIQUE bar(`foo`)', - statements[0] - ) + self.assertEqual("ALTER TABLE `users` ADD UNIQUE bar(`foo`)", statements[0]) def test_adding_index(self): - blueprint = Blueprint('users') - blueprint.index(['foo', 'bar'], 'baz') + blueprint = Blueprint("users") + blueprint.index(["foo", "bar"], "baz") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD INDEX baz(`foo`, `bar`)', - statements[0] + "ALTER TABLE `users` ADD INDEX baz(`foo`, `bar`)", statements[0] ) def test_adding_incrementing_id(self): - blueprint = Blueprint('users') - blueprint.increments('id') + blueprint = Blueprint("users") + blueprint.increments("id") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY', - statements[0] + "ALTER TABLE `users` ADD `id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY", + statements[0], ) def test_adding_big_incrementing_id(self): - 
blueprint = Blueprint('users') - blueprint.big_increments('id') + blueprint = Blueprint("users") + blueprint.big_increments("id") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY', - statements[0] + "ALTER TABLE `users` ADD `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY", + statements[0], ) - + def test_adding_column_after_another(self): - blueprint = Blueprint('users') - blueprint.string('name').after('foo') - + blueprint = Blueprint("users") + blueprint.string("name").after("foo") + statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL AFTER `foo`', - statements[0] + "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL AFTER `foo`", + statements[0], ) def test_adding_string(self): - blueprint = Blueprint('users') - blueprint.string('foo') + blueprint = Blueprint("users") + blueprint.string("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` VARCHAR(255) NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` VARCHAR(255) NOT NULL", statements[0] ) - blueprint = Blueprint('users') - blueprint.string('foo', 100) + blueprint = Blueprint("users") + blueprint.string("foo", 100) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` VARCHAR(100) NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` VARCHAR(100) NOT NULL", statements[0] ) - blueprint = Blueprint('users') - blueprint.string('foo', 100).nullable().default('bar') + blueprint = Blueprint("users") + blueprint.string("foo", 100).nullable().default("bar") statements = 
blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` VARCHAR(100) NULL DEFAULT \'bar\'', - statements[0] + "ALTER TABLE `users` ADD `foo` VARCHAR(100) NULL DEFAULT 'bar'", + statements[0], ) def test_adding_text(self): - blueprint = Blueprint('users') - blueprint.text('foo') + blueprint = Blueprint("users") + blueprint.text("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD `foo` TEXT NOT NULL', - statements[0] - ) + self.assertEqual("ALTER TABLE `users` ADD `foo` TEXT NOT NULL", statements[0]) def test_adding_big_integer(self): - blueprint = Blueprint('users') - blueprint.big_integer('foo') + blueprint = Blueprint("users") + blueprint.big_integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD `foo` BIGINT NOT NULL', - statements[0] - ) + self.assertEqual("ALTER TABLE `users` ADD `foo` BIGINT NOT NULL", statements[0]) - blueprint = Blueprint('users') - blueprint.big_integer('foo', True) + blueprint = Blueprint("users") + blueprint.big_integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY', - statements[0] + "ALTER TABLE `users` ADD `foo` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY", + statements[0], ) def test_adding_integer(self): - blueprint = Blueprint('users') - blueprint.integer('foo') + blueprint = Blueprint("users") + blueprint.integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD `foo` INT NOT NULL', - statements[0] - ) + self.assertEqual("ALTER TABLE 
`users` ADD `foo` INT NOT NULL", statements[0]) - blueprint = Blueprint('users') - blueprint.integer('foo', True) + blueprint = Blueprint("users") + blueprint.integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` INT NOT NULL AUTO_INCREMENT PRIMARY KEY', - statements[0] + "ALTER TABLE `users` ADD `foo` INT NOT NULL AUTO_INCREMENT PRIMARY KEY", + statements[0], ) def test_adding_medium_integer(self): - blueprint = Blueprint('users') - blueprint.medium_integer('foo') + blueprint = Blueprint("users") + blueprint.medium_integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL", statements[0] ) - blueprint = Blueprint('users') - blueprint.medium_integer('foo', True) + blueprint = Blueprint("users") + blueprint.medium_integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL AUTO_INCREMENT PRIMARY KEY', - statements[0] + "ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL AUTO_INCREMENT PRIMARY KEY", + statements[0], ) def test_adding_tiny_integer(self): - blueprint = Blueprint('users') - blueprint.tiny_integer('foo') + blueprint = Blueprint("users") + blueprint.tiny_integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` TINYINT NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` TINYINT NOT NULL", statements[0] ) - blueprint = Blueprint('users') - blueprint.tiny_integer('foo', True) + blueprint = Blueprint("users") + blueprint.tiny_integer("foo", True) statements = 
blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` TINYINT NOT NULL AUTO_INCREMENT PRIMARY KEY', - statements[0] + "ALTER TABLE `users` ADD `foo` TINYINT NOT NULL AUTO_INCREMENT PRIMARY KEY", + statements[0], ) def test_adding_small_integer(self): - blueprint = Blueprint('users') - blueprint.small_integer('foo') + blueprint = Blueprint("users") + blueprint.small_integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL", statements[0] ) - blueprint = Blueprint('users') - blueprint.small_integer('foo', True) + blueprint = Blueprint("users") + blueprint.small_integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL AUTO_INCREMENT PRIMARY KEY', - statements[0] + "ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL AUTO_INCREMENT PRIMARY KEY", + statements[0], ) def test_adding_float(self): - blueprint = Blueprint('users') - blueprint.float('foo', 5, 2) + blueprint = Blueprint("users") + blueprint.float("foo", 5, 2) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` DOUBLE(5, 2) NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` DOUBLE(5, 2) NOT NULL", statements[0] ) def test_adding_double(self): - blueprint = Blueprint('users') - blueprint.double('foo') + blueprint = Blueprint("users") + blueprint.double("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD `foo` DOUBLE NOT NULL', - statements[0] - ) + 
self.assertEqual("ALTER TABLE `users` ADD `foo` DOUBLE NOT NULL", statements[0]) def test_adding_double_with_precision(self): - blueprint = Blueprint('users') - blueprint.double('foo', 15, 8) + blueprint = Blueprint("users") + blueprint.double("foo", 15, 8) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` DOUBLE(15, 8) NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` DOUBLE(15, 8) NOT NULL", statements[0] ) def test_adding_decimal(self): - blueprint = Blueprint('users') - blueprint.decimal('foo', 5, 2) + blueprint = Blueprint("users") + blueprint.decimal("foo", 5, 2) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` DECIMAL(5, 2) NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` DECIMAL(5, 2) NOT NULL", statements[0] ) def test_adding_boolean(self): - blueprint = Blueprint('users') - blueprint.boolean('foo') + blueprint = Blueprint("users") + blueprint.boolean("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` TINYINT(1) NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` TINYINT(1) NOT NULL", statements[0] ) def test_adding_enum(self): - blueprint = Blueprint('users') - blueprint.enum('foo', ['bar', 'baz']) + blueprint = Blueprint("users") + blueprint.enum("foo", ["bar", "baz"]) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` ENUM(\'bar\', \'baz\') NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` ENUM('bar', 'baz') NOT NULL", statements[0] ) def test_adding_date(self): - blueprint = Blueprint('users') - blueprint.date('foo') + blueprint = Blueprint("users") + blueprint.date("foo") 
statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD `foo` DATE NOT NULL', - statements[0] - ) + self.assertEqual("ALTER TABLE `users` ADD `foo` DATE NOT NULL", statements[0]) def test_adding_datetime(self): - blueprint = Blueprint('users') - blueprint.datetime('foo') + blueprint = Blueprint("users") + blueprint.datetime("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` DATETIME NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` DATETIME NOT NULL", statements[0] ) def test_adding_time(self): - blueprint = Blueprint('users') - blueprint.time('foo') + blueprint = Blueprint("users") + blueprint.time("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + self.assertEqual(1, len(statements)) + self.assertEqual("ALTER TABLE `users` ADD `foo` TIME NOT NULL", statements[0]) + + def test_adding_timestamp_mysql_lt_564(self): + blueprint = Blueprint("users") + blueprint.timestamp("foo") + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((5, 6, 0, "")) + ) + self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` TIME NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` TIMESTAMP NOT NULL", statements[0] ) - def test_adding_timestamp(self): - blueprint = Blueprint('users') - blueprint.timestamp('foo') - statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + def test_adding_timestamp_mysql_gte_564(self): + blueprint = Blueprint("users") + blueprint.timestamp("foo") + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((5, 6, 4, "")) + ) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` TIMESTAMP NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` TIMESTAMP(6) NOT NULL", 
statements[0] ) - def test_adding_timestamp_with_current(self): - blueprint = Blueprint('users') - blueprint.timestamp('foo').use_current() - statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + def test_adding_timestamp_with_current_mysql_lt_564(self): + blueprint = Blueprint("users") + blueprint.timestamp("foo").use_current() + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((5, 6, 0, "")) + ) + + self.assertEqual(1, len(statements)) + self.assertEqual( + "ALTER TABLE `users` ADD `foo` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL", + statements[0], + ) + + def test_adding_timestamp_with_current_mysql_gte_564(self): + blueprint = Blueprint("users") + blueprint.timestamp("foo").use_current() + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((5, 6, 4, "")) + ) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE `users` ADD `foo` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL', - statements[0] + "ALTER TABLE `users` ADD `foo` TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP(6) NOT NULL", + statements[0], ) - def test_adding_timestamps(self): - blueprint = Blueprint('users') + def test_adding_timestamps_mysql_lt_564(self): + blueprint = Blueprint("users") blueprint.timestamps() - statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((5, 6, 0, "")) + ) self.assertEqual(1, len(statements)) expected = [ - 'ALTER TABLE `users` ADD `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, ' - 'ADD `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL' + "ALTER TABLE `users` ADD `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, " + "ADD `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL" ] - self.assertEqual( - expected[0], - statements[0] + self.assertEqual(expected[0], statements[0]) + + def test_adding_timestamps_mysql_gte_564(self): + blueprint = Blueprint("users") + 
blueprint.timestamps() + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((5, 6, 4, "")) ) - def test_adding_timestamps_not_current(self): - blueprint = Blueprint('users') + self.assertEqual(1, len(statements)) + expected = [ + "ALTER TABLE `users` ADD `created_at` TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP(6) NOT NULL, " + "ADD `updated_at` TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP(6) NOT NULL" + ] + self.assertEqual(expected[0], statements[0]) + + def test_adding_timestamps_not_current_mysql_lt_564(self): + blueprint = Blueprint("users") blueprint.timestamps(use_current=False) - statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((5, 6, 0, "")) + ) self.assertEqual(1, len(statements)) expected = [ - 'ALTER TABLE `users` ADD `created_at` TIMESTAMP NOT NULL, ' - 'ADD `updated_at` TIMESTAMP NOT NULL' + "ALTER TABLE `users` ADD `created_at` TIMESTAMP NOT NULL, " + "ADD `updated_at` TIMESTAMP NOT NULL" ] - self.assertEqual( - expected[0], - statements[0] + self.assertEqual(expected[0], statements[0]) + + def test_adding_timestamps_not_current_mysql_gte_564(self): + blueprint = Blueprint("users") + blueprint.timestamps(use_current=False) + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((5, 6, 4, "")) ) + self.assertEqual(1, len(statements)) + expected = [ + "ALTER TABLE `users` ADD `created_at` TIMESTAMP(6) NOT NULL, " + "ADD `updated_at` TIMESTAMP(6) NOT NULL" + ] + self.assertEqual(expected[0], statements[0]) + def test_adding_binary(self): - blueprint = Blueprint('users') - blueprint.binary('foo') + blueprint = Blueprint("users") + blueprint.binary("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD `foo` BLOB NOT NULL', - statements[0] - ) + self.assertEqual("ALTER TABLE `users` ADD `foo` BLOB NOT NULL", statements[0]) 
def test_adding_json(self): - blueprint = Blueprint('users') - blueprint.json('foo') + blueprint = Blueprint("users") + blueprint.json("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD `foo` JSON NOT NULL', - statements[0] - ) + self.assertEqual("ALTER TABLE `users` ADD `foo` JSON NOT NULL", statements[0]) def test_adding_json_mysql_56(self): - blueprint = Blueprint('users') - blueprint.json('foo') + blueprint = Blueprint("users") + blueprint.json("foo") - statements = blueprint.to_sql(self.get_connection(), self.get_grammar((5, 6, 0, ''))) + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((5, 6, 0, "")) + ) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD `foo` TEXT NOT NULL', - statements[0] - ) + self.assertEqual("ALTER TABLE `users` ADD `foo` TEXT NOT NULL", statements[0]) def test_adding_json_mariadb(self): - blueprint = Blueprint('users') - blueprint.json('foo') + blueprint = Blueprint("users") + blueprint.json("foo") - statements = blueprint.to_sql(self.get_connection(), self.get_grammar((10, 6, 0, 'mariadb'))) + statements = blueprint.to_sql( + self.get_connection(), self.get_grammar((10, 6, 0, "mariadb")) + ) self.assertEqual(1, len(statements)) - self.assertEqual( - 'ALTER TABLE `users` ADD `foo` TEXT NOT NULL', - statements[0] - ) + self.assertEqual("ALTER TABLE `users` ADD `foo` TEXT NOT NULL", statements[0]) def get_connection(self, version=None): if version is None: - version = (5, 7, 11, '') + version = (5, 7, 11, "") connector = flexmock(MySQLConnector()) - connector.should_receive('get_server_version').and_return(version) + connector.should_receive("get_server_version").and_return(version) conn = flexmock(Connection(connector)) return conn diff --git a/tests/schema/grammars/test_postgres_grammar.py b/tests/schema/grammars/test_postgres_grammar.py index f7778d7a..e1bf70b1 100644 
--- a/tests/schema/grammars/test_postgres_grammar.py +++ b/tests/schema/grammars/test_postgres_grammar.py @@ -8,26 +8,25 @@ class PostgresSchemaGrammarTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_basic_create(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.create() - blueprint.increments('id') - blueprint.string('email') + blueprint.increments("id") + blueprint.string("email") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'CREATE TABLE "users" ("id" SERIAL PRIMARY KEY NOT NULL, "email" VARCHAR(255) NOT NULL)', - statements[0] + statements[0], ) - blueprint = Blueprint('users') - blueprint.increments('id') - blueprint.string('email') + blueprint = Blueprint("users") + blueprint.increments("id") + blueprint.string("email") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) @@ -38,7 +37,7 @@ def test_basic_create(self): self.assertEqual(expected[0], statements[0]) def test_drop_table(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.drop() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) @@ -46,7 +45,7 @@ def test_drop_table(self): self.assertEqual('DROP TABLE "users"', statements[0]) def test_drop_table_if_exists(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.drop_if_exists() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) @@ -54,82 +53,89 @@ def test_drop_table_if_exists(self): self.assertEqual('DROP TABLE IF EXISTS "users"', statements[0]) def test_drop_column(self): - blueprint = Blueprint('users') - blueprint.drop_column('foo') + blueprint = Blueprint("users") + blueprint.drop_column("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual('ALTER TABLE "users" DROP 
COLUMN "foo"', statements[0]) - blueprint = Blueprint('users') - blueprint.drop_column('foo', 'bar') + blueprint = Blueprint("users") + blueprint.drop_column("foo", "bar") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE "users" DROP COLUMN "foo", DROP COLUMN "bar"', statements[0]) + self.assertEqual( + 'ALTER TABLE "users" DROP COLUMN "foo", DROP COLUMN "bar"', statements[0] + ) def test_drop_primary(self): - blueprint = Blueprint('users') - blueprint.drop_primary('foo') + blueprint = Blueprint("users") + blueprint.drop_primary("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE "users" DROP CONSTRAINT users_pkey', statements[0]) + self.assertEqual( + 'ALTER TABLE "users" DROP CONSTRAINT users_pkey', statements[0] + ) def test_drop_unique(self): - blueprint = Blueprint('users') - blueprint.drop_unique('foo') + blueprint = Blueprint("users") + blueprint.drop_unique("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual('ALTER TABLE "users" DROP CONSTRAINT foo', statements[0]) def test_drop_index(self): - blueprint = Blueprint('users') - blueprint.drop_index('foo') + blueprint = Blueprint("users") + blueprint.drop_index("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('DROP INDEX foo', statements[0]) + self.assertEqual("DROP INDEX foo", statements[0]) def test_drop_foreign(self): - blueprint = Blueprint('users') - blueprint.drop_unique('foo') + blueprint = Blueprint("users") + blueprint.drop_unique("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual('ALTER TABLE "users" DROP CONSTRAINT foo', statements[0]) def 
test_drop_timestamps(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.drop_timestamps() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('ALTER TABLE "users" DROP COLUMN "created_at", DROP COLUMN "updated_at"', statements[0]) + self.assertEqual( + 'ALTER TABLE "users" DROP COLUMN "created_at", DROP COLUMN "updated_at"', + statements[0], + ) def test_rename_table(self): - blueprint = Blueprint('users') - blueprint.rename('foo') + blueprint = Blueprint("users") + blueprint.rename("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual('ALTER TABLE "users" RENAME TO "foo"', statements[0]) def test_adding_primary_key(self): - blueprint = Blueprint('users') - blueprint.primary('foo') + blueprint = Blueprint("users") + blueprint.primary("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual('ALTER TABLE "users" ADD PRIMARY KEY ("foo")', statements[0]) def test_adding_foreign_key(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.create() - blueprint.string('foo').primary() - blueprint.string('order_id') - blueprint.foreign('order_id').references('id').on('orders') + blueprint.string("foo").primary() + blueprint.string("order_id") + blueprint.foreign("order_id").references("id").on("orders") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(3, len(statements)) @@ -137,362 +143,341 @@ def test_adding_foreign_key(self): 'CREATE TABLE "users" ("foo" VARCHAR(255) NOT NULL, "order_id" VARCHAR(255) NOT NULL)', 'ALTER TABLE "users" ADD CONSTRAINT users_order_id_foreign' ' FOREIGN KEY ("order_id") REFERENCES "orders" ("id")', - 'ALTER TABLE "users" ADD PRIMARY KEY ("foo")' + 'ALTER TABLE "users" ADD PRIMARY KEY ("foo")', ] 
self.assertEqual(expected, statements) def test_adding_unique_key(self): - blueprint = Blueprint('users') - blueprint.unique('foo', 'bar') + blueprint = Blueprint("users") + blueprint.unique("foo", "bar") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD CONSTRAINT bar UNIQUE ("foo")', - statements[0] + 'ALTER TABLE "users" ADD CONSTRAINT bar UNIQUE ("foo")', statements[0] ) def test_adding_index(self): - blueprint = Blueprint('users') - blueprint.index(['foo', 'bar'], 'baz') + blueprint = Blueprint("users") + blueprint.index(["foo", "bar"], "baz") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'CREATE INDEX baz ON "users" ("foo", "bar")', - statements[0] - ) + self.assertEqual('CREATE INDEX baz ON "users" ("foo", "bar")', statements[0]) def test_adding_incrementing_id(self): - blueprint = Blueprint('users') - blueprint.increments('id') + blueprint = Blueprint("users") + blueprint.increments("id") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "id" SERIAL PRIMARY KEY NOT NULL', - statements[0] + statements[0], ) def test_adding_big_incrementing_id(self): - blueprint = Blueprint('users') - blueprint.big_increments('id') + blueprint = Blueprint("users") + blueprint.big_increments("id") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "id" BIGSERIAL PRIMARY KEY NOT NULL', - statements[0] + statements[0], ) def test_adding_string(self): - blueprint = Blueprint('users') - blueprint.string('foo') + blueprint = Blueprint("users") + blueprint.string("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) 
self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(255) NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(255) NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.string('foo', 100) + blueprint = Blueprint("users") + blueprint.string("foo", 100) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(100) NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(100) NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.string('foo', 100).nullable().default('bar') + blueprint = Blueprint("users") + blueprint.string("foo", 100).nullable().default("bar") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(100) NULL DEFAULT \'bar\'', - statements[0] + statements[0], ) def test_adding_text(self): - blueprint = Blueprint('users') - blueprint.text('foo') + blueprint = Blueprint("users") + blueprint.text("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', statements[0] ) def test_adding_big_integer(self): - blueprint = Blueprint('users') - blueprint.big_integer('foo') + blueprint = Blueprint("users") + blueprint.big_integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" BIGINT NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" BIGINT NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.big_integer('foo', True) + blueprint = Blueprint("users") + blueprint.big_integer("foo", 
True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" BIGSERIAL PRIMARY KEY NOT NULL', - statements[0] + statements[0], ) def test_adding_integer(self): - blueprint = Blueprint('users') - blueprint.integer('foo') + blueprint = Blueprint("users") + blueprint.integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.integer('foo', True) + blueprint = Blueprint("users") + blueprint.integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" SERIAL PRIMARY KEY NOT NULL', - statements[0] + statements[0], ) def test_adding_medium_integer(self): - blueprint = Blueprint('users') - blueprint.medium_integer('foo') + blueprint = Blueprint("users") + blueprint.medium_integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.medium_integer('foo', True) + blueprint = Blueprint("users") + blueprint.medium_integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" SERIAL PRIMARY KEY NOT NULL', - statements[0] + statements[0], ) def test_adding_tiny_integer(self): - blueprint = Blueprint('users') - blueprint.tiny_integer('foo') + blueprint = Blueprint("users") + 
blueprint.tiny_integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.tiny_integer('foo', True) + blueprint = Blueprint("users") + blueprint.tiny_integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" SMALLSERIAL PRIMARY KEY NOT NULL', - statements[0] + statements[0], ) def test_adding_small_integer(self): - blueprint = Blueprint('users') - blueprint.small_integer('foo') + blueprint = Blueprint("users") + blueprint.small_integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.small_integer('foo', True) + blueprint = Blueprint("users") + blueprint.small_integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" SMALLSERIAL PRIMARY KEY NOT NULL', - statements[0] + statements[0], ) def test_adding_float(self): - blueprint = Blueprint('users') - blueprint.float('foo', 5, 2) + blueprint = Blueprint("users") + blueprint.float("foo", 5, 2) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" DOUBLE PRECISION NOT NULL', - statements[0] + statements[0], ) def test_adding_double(self): - blueprint = Blueprint('users') - blueprint.double('foo', 15, 8) 
+ blueprint = Blueprint("users") + blueprint.double("foo", 15, 8) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" DOUBLE PRECISION NOT NULL', - statements[0] + statements[0], ) def test_adding_decimal(self): - blueprint = Blueprint('users') - blueprint.decimal('foo', 5, 2) + blueprint = Blueprint("users") + blueprint.decimal("foo", 5, 2) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" DECIMAL(5, 2) NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" DECIMAL(5, 2) NOT NULL', statements[0] ) def test_adding_boolean(self): - blueprint = Blueprint('users') - blueprint.boolean('foo') + blueprint = Blueprint("users") + blueprint.boolean("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" BOOLEAN NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" BOOLEAN NOT NULL', statements[0] ) def test_adding_enum(self): - blueprint = Blueprint('users') - blueprint.enum('foo', ['bar', 'baz']) + blueprint = Blueprint("users") + blueprint.enum("foo", ["bar", "baz"]) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(255) CHECK ("foo" IN (\'bar\', \'baz\')) NOT NULL', - statements[0] + statements[0], ) def test_adding_date(self): - blueprint = Blueprint('users') - blueprint.date('foo') + blueprint = Blueprint("users") + blueprint.date("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" DATE NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" DATE 
NOT NULL', statements[0] ) def test_adding_datetime(self): - blueprint = Blueprint('users') - blueprint.datetime('foo') + blueprint = Blueprint("users") + blueprint.datetime("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(6) WITHOUT TIME ZONE NOT NULL', + statements[0], ) def test_adding_time(self): - blueprint = Blueprint('users') - blueprint.time('foo') + blueprint = Blueprint("users") + blueprint.time("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" TIME(0) WITHOUT TIME ZONE NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" TIME(6) WITHOUT TIME ZONE NOT NULL', + statements[0], ) def test_adding_timestamp(self): - blueprint = Blueprint('users') - blueprint.timestamp('foo') + blueprint = Blueprint("users") + blueprint.timestamp("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(6) WITHOUT TIME ZONE NOT NULL', + statements[0], ) def test_adding_timestamp_with_current(self): - blueprint = Blueprint('users') - blueprint.timestamp('foo').use_current() + blueprint = Blueprint("users") + blueprint.timestamp("foo").use_current() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(0) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(0) NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(6) WITHOUT TIME ZONE ' + "DEFAULT 
CURRENT_TIMESTAMP(6) NOT NULL", + statements[0], ) def test_adding_timestamps(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.timestamps() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) expected = [ - 'ALTER TABLE "users" ADD COLUMN "created_at" TIMESTAMP(0) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(0) NOT NULL, ' - 'ADD COLUMN "updated_at" TIMESTAMP(0) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(0) NOT NULL' + 'ALTER TABLE "users" ADD COLUMN "created_at" TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(6) NOT NULL, ' + 'ADD COLUMN "updated_at" TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(6) NOT NULL' ] - self.assertEqual( - expected[0], - statements[0] - ) + + self.assertEqual(expected[0], statements[0]) def test_adding_timestamps_not_current(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.timestamps(use_current=False) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) expected = [ - 'ALTER TABLE "users" ADD COLUMN "created_at" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL, ' - 'ADD COLUMN "updated_at" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL' + 'ALTER TABLE "users" ADD COLUMN "created_at" TIMESTAMP(6) WITHOUT TIME ZONE NOT NULL, ' + 'ADD COLUMN "updated_at" TIMESTAMP(6) WITHOUT TIME ZONE NOT NULL' ] - self.assertEqual( - expected[0], - statements[0] - ) + self.assertEqual(expected[0], statements[0]) def test_adding_binary(self): - blueprint = Blueprint('users') - blueprint.binary('foo') + blueprint = Blueprint("users") + blueprint.binary("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" BYTEA NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" BYTEA NOT NULL', statements[0] ) def test_adding_json(self): - blueprint = 
Blueprint('users') - blueprint.json('foo') + blueprint = Blueprint("users") + blueprint.json("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" JSON NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" JSON NOT NULL', statements[0] ) def get_connection(self): diff --git a/tests/schema/grammars/test_sqlite_grammar.py b/tests/schema/grammars/test_sqlite_grammar.py index b646f55c..a5ed7cf0 100644 --- a/tests/schema/grammars/test_sqlite_grammar.py +++ b/tests/schema/grammars/test_sqlite_grammar.py @@ -8,37 +8,36 @@ class SqliteSchemaGrammarTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_basic_create(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.create() - blueprint.increments('id') - blueprint.string('email') + blueprint.increments("id") + blueprint.string("email") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'CREATE TABLE "users" ("id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, "email" VARCHAR NOT NULL)', - statements[0] + statements[0], ) - blueprint = Blueprint('users') - blueprint.increments('id') - blueprint.string('email') + blueprint = Blueprint("users") + blueprint.increments("id") + blueprint.string("email") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(2, len(statements)) expected = [ 'ALTER TABLE "users" ADD COLUMN "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT', - 'ALTER TABLE "users" ADD COLUMN "email" VARCHAR NOT NULL' + 'ALTER TABLE "users" ADD COLUMN "email" VARCHAR NOT NULL', ] self.assertEqual(expected, statements) def test_drop_table(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.drop() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) @@ -46,7 +45,7 @@ def 
test_drop_table(self): self.assertEqual('DROP TABLE "users"', statements[0]) def test_drop_table_if_exists(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.drop_if_exists() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) @@ -54,364 +53,335 @@ def test_drop_table_if_exists(self): self.assertEqual('DROP TABLE IF EXISTS "users"', statements[0]) def test_drop_unique(self): - blueprint = Blueprint('users') - blueprint.drop_unique('foo') + blueprint = Blueprint("users") + blueprint.drop_unique("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('DROP INDEX foo', statements[0]) + self.assertEqual("DROP INDEX foo", statements[0]) def test_drop_index(self): - blueprint = Blueprint('users') - blueprint.drop_index('foo') + blueprint = Blueprint("users") + blueprint.drop_index("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual('DROP INDEX foo', statements[0]) + self.assertEqual("DROP INDEX foo", statements[0]) def test_rename_table(self): - blueprint = Blueprint('users') - blueprint.rename('foo') + blueprint = Blueprint("users") + blueprint.rename("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual('ALTER TABLE "users" RENAME TO "foo"', statements[0]) def test_adding_foreign_key(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.create() - blueprint.string('foo').primary() - blueprint.string('order_id') - blueprint.foreign('order_id').references('id').on('orders') + blueprint.string("foo").primary() + blueprint.string("order_id") + blueprint.foreign("order_id").references("id").on("orders") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - expected = 'CREATE TABLE "users" ("foo" 
VARCHAR NOT NULL, "order_id" VARCHAR NOT NULL, ' \ - 'FOREIGN KEY("order_id") REFERENCES "orders"("id"), PRIMARY KEY ("foo"))' + expected = ( + 'CREATE TABLE "users" ("foo" VARCHAR NOT NULL, "order_id" VARCHAR NOT NULL, ' + 'FOREIGN KEY("order_id") REFERENCES "orders"("id"), PRIMARY KEY ("foo"))' + ) self.assertEqual(expected, statements[0]) def test_adding_unique_key(self): - blueprint = Blueprint('users') - blueprint.unique('foo', 'bar') + blueprint = Blueprint("users") + blueprint.unique("foo", "bar") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'CREATE UNIQUE INDEX bar ON "users" ("foo")', - statements[0] - ) + self.assertEqual('CREATE UNIQUE INDEX bar ON "users" ("foo")', statements[0]) def test_adding_index(self): - blueprint = Blueprint('users') - blueprint.index('foo', 'bar') + blueprint = Blueprint("users") + blueprint.index("foo", "bar") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) - self.assertEqual( - 'CREATE INDEX bar ON "users" ("foo")', - statements[0] - ) + self.assertEqual('CREATE INDEX bar ON "users" ("foo")', statements[0]) def test_adding_incrementing_id(self): - blueprint = Blueprint('users') - blueprint.increments('id') + blueprint = Blueprint("users") + blueprint.increments("id") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT', - statements[0] + statements[0], ) def test_adding_big_incrementing_id(self): - blueprint = Blueprint('users') - blueprint.big_increments('id') + blueprint = Blueprint("users") + blueprint.big_increments("id") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "id" INTEGER NOT NULL PRIMARY KEY 
AUTOINCREMENT', - statements[0] + statements[0], ) def test_adding_string(self): - blueprint = Blueprint('users') - blueprint.string('foo') + blueprint = Blueprint("users") + blueprint.string("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.string('foo', 100) + blueprint = Blueprint("users") + blueprint.string("foo", 100) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.string('foo', 100).nullable().default('bar') + blueprint = Blueprint("users") + blueprint.string("foo", 100).nullable().default("bar") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NULL DEFAULT \'bar\'', - statements[0] + statements[0], ) def test_adding_text(self): - blueprint = Blueprint('users') - blueprint.text('foo') + blueprint = Blueprint("users") + blueprint.text("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', statements[0] ) def test_adding_big_integer(self): - blueprint = Blueprint('users') - blueprint.big_integer('foo') + blueprint = Blueprint("users") + blueprint.big_integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 
- 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.big_integer('foo', True) + blueprint = Blueprint("users") + blueprint.big_integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT', - statements[0] + statements[0], ) def test_adding_integer(self): - blueprint = Blueprint('users') - blueprint.integer('foo') + blueprint = Blueprint("users") + blueprint.integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0] ) - blueprint = Blueprint('users') - blueprint.integer('foo', True) + blueprint = Blueprint("users") + blueprint.integer("foo", True) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT', - statements[0] + statements[0], ) def test_adding_medium_integer(self): - blueprint = Blueprint('users') - blueprint.integer('foo') + blueprint = Blueprint("users") + blueprint.integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0] ) def test_adding_tiny_integer(self): - blueprint = Blueprint('users') - blueprint.integer('foo') + blueprint = Blueprint("users") + blueprint.integer("foo") statements = blueprint.to_sql(self.get_connection(), 
self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0] ) def test_adding_small_integer(self): - blueprint = Blueprint('users') - blueprint.integer('foo') + blueprint = Blueprint("users") + blueprint.integer("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0] ) def test_adding_float(self): - blueprint = Blueprint('users') - blueprint.float('foo', 5, 2) + blueprint = Blueprint("users") + blueprint.float("foo", 5, 2) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL', statements[0] ) def test_adding_double(self): - blueprint = Blueprint('users') - blueprint.double('foo', 15, 8) + blueprint = Blueprint("users") + blueprint.double("foo", 15, 8) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL', statements[0] ) def test_adding_decimal(self): - blueprint = Blueprint('users') - blueprint.decimal('foo', 5, 2) + blueprint = Blueprint("users") + blueprint.decimal("foo", 5, 2) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" NUMERIC NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" NUMERIC NOT NULL', statements[0] ) def test_adding_boolean(self): 
- blueprint = Blueprint('users') - blueprint.boolean('foo') + blueprint = Blueprint("users") + blueprint.boolean("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" TINYINT NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" TINYINT NOT NULL', statements[0] ) def test_adding_enum(self): - blueprint = Blueprint('users') - blueprint.enum('foo', ['bar', 'baz']) + blueprint = Blueprint("users") + blueprint.enum("foo", ["bar", "baz"]) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', statements[0] ) def test_adding_date(self): - blueprint = Blueprint('users') - blueprint.date('foo') + blueprint = Blueprint("users") + blueprint.date("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" DATE NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" DATE NOT NULL', statements[0] ) def test_adding_datetime(self): - blueprint = Blueprint('users') - blueprint.datetime('foo') + blueprint = Blueprint("users") + blueprint.datetime("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL', statements[0] ) def test_adding_time(self): - blueprint = Blueprint('users') - blueprint.time('foo') + blueprint = Blueprint("users") + blueprint.time("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD 
COLUMN "foo" TIME NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" TIME NOT NULL', statements[0] ) def test_adding_timestamp(self): - blueprint = Blueprint('users') - blueprint.timestamp('foo') + blueprint = Blueprint("users") + blueprint.timestamp("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL', statements[0] ) def test_adding_timestamp_with_current(self): - blueprint = Blueprint('users') - blueprint.timestamp('foo').use_current() + blueprint = Blueprint("users") + blueprint.timestamp("foo").use_current() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL', - statements[0] + statements[0], ) def test_adding_timestamps(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.timestamps() statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(2, len(statements)) expected = [ 'ALTER TABLE "users" ADD COLUMN "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL', - 'ALTER TABLE "users" ADD COLUMN "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL' + 'ALTER TABLE "users" ADD COLUMN "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL', ] - self.assertEqual( - expected, - statements - ) + self.assertEqual(expected, statements) def test_adding_timestamps_not_current(self): - blueprint = Blueprint('users') + blueprint = Blueprint("users") blueprint.timestamps(use_current=False) statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(2, len(statements)) expected = [ 'ALTER TABLE "users" ADD COLUMN "created_at" DATETIME NOT NULL', - 'ALTER TABLE "users" ADD COLUMN "updated_at" 
DATETIME NOT NULL' + 'ALTER TABLE "users" ADD COLUMN "updated_at" DATETIME NOT NULL', ] - self.assertEqual( - expected, - statements - ) + self.assertEqual(expected, statements) def test_adding_binary(self): - blueprint = Blueprint('users') - blueprint.binary('foo') + blueprint = Blueprint("users") + blueprint.binary("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" BLOB NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" BLOB NOT NULL', statements[0] ) def test_adding_json(self): - blueprint = Blueprint('users') - blueprint.json('foo') + blueprint = Blueprint("users") + blueprint.json("foo") statements = blueprint.to_sql(self.get_connection(), self.get_grammar()) self.assertEqual(1, len(statements)) self.assertEqual( - 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', - statements[0] + 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', statements[0] ) def get_connection(self): diff --git a/tests/schema/integrations/__init__.py b/tests/schema/integrations/__init__.py index 3823997c..b8724bb2 100644 --- a/tests/schema/integrations/__init__.py +++ b/tests/schema/integrations/__init__.py @@ -1,13 +1,19 @@ # -*- coding: utf-8 -*- from orator import Model -from orator.orm import has_one, has_many, belongs_to, belongs_to_many, morph_to, morph_many +from orator.orm import ( + has_one, + has_many, + belongs_to, + belongs_to_many, + morph_to, + morph_many, +) from orator import QueryExpression from orator.dbal.exceptions import ColumnDoesNotExist class IntegrationTestCase(object): - @classmethod def setUpClass(cls): Model.set_connection_resolver(cls.get_connection_resolver()) @@ -22,173 +28,205 @@ def tearDownClass(cls): def setUp(self): with self.connection().transaction(): - self.schema().drop_if_exists('photos') - self.schema().drop_if_exists('posts') - self.schema().drop_if_exists('friends') - 
self.schema().drop_if_exists('users') - - with self.schema().create('users') as table: - table.increments('id') - table.string('email').unique() - table.integer('votes').default(0) + self.schema().drop_if_exists("photos") + self.schema().drop_if_exists("posts") + self.schema().drop_if_exists("friends") + self.schema().drop_if_exists("users") + + with self.schema().create("users") as table: + table.increments("id") + table.string("email").unique() + table.integer("votes").default(0) table.timestamps(use_current=True) - with self.schema().create('friends') as table: - table.integer('user_id', unsigned=True) - table.integer('friend_id', unsigned=True) - - table.foreign('user_id').references('id').on('users').on_delete('cascade') - table.foreign('friend_id').references('id').on('users').on_delete('cascade') - - with self.schema().create('posts') as table: - table.increments('id') - table.integer('user_id', unsigned=True) - table.string('name').unique() - table.enum('status', ['draft', 'published']).default('draft').nullable() - table.string('default').default(0) - table.string('tag').nullable().default('tag') + with self.schema().create("friends") as table: + table.integer("user_id", unsigned=True) + table.integer("friend_id", unsigned=True) + + table.foreign("user_id").references("id").on("users").on_delete( + "cascade" + ) + table.foreign("friend_id").references("id").on("users").on_delete( + "cascade" + ) + + with self.schema().create("posts") as table: + table.increments("id") + table.integer("user_id", unsigned=True) + table.string("name").unique() + table.enum("status", ["draft", "published"]).default("draft").nullable() + table.string("default").default(0) + table.string("tag").nullable().default("tag") table.timestamps(use_current=True) - table.foreign('user_id', 'users_foreign_key').references('id').on('users') + table.foreign("user_id", "users_foreign_key").references("id").on( + "users" + ) - with self.schema().create('photos') as table: - 
table.increments('id') - table.morphs('imageable') - table.string('name') + with self.schema().create("photos") as table: + table.increments("id") + table.morphs("imageable") + table.string("name") table.timestamps(use_current=True) for i in range(10): - user = User.create(email='user%d@foo.com' % (i + 1)) + user = User.create(email="user%d@foo.com" % (i + 1)) for j in range(10): - post = Post(name='User %d Post %d' % (user.id, j + 1)) + post = Post(name="User %d Post %d" % (user.id, j + 1)) user.posts().save(post) def tearDown(self): - self.schema().drop('photos') - self.schema().drop('posts') - self.schema().drop('friends') - self.schema().drop('users') + self.schema().drop("photos") + self.schema().drop("posts") + self.schema().drop("friends") + self.schema().drop("users") def test_foreign_keys_creation(self): - posts_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') - friends_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('friends') + posts_foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("posts") + ) + friends_foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("friends") + ) - self.assertEqual('users_foreign_key', posts_foreign_keys[0].get_name()) + self.assertEqual("users_foreign_key", posts_foreign_keys[0].get_name()) self.assertEqual( - ['friends_friend_id_foreign', 'friends_user_id_foreign'], - sorted([f.get_name() for f in friends_foreign_keys]) + ["friends_friend_id_foreign", "friends_user_id_foreign"], + sorted([f.get_name() for f in friends_foreign_keys]), ) def test_add_columns(self): - with self.schema().table('posts') as table: - table.text('content').nullable() - table.integer('votes').default(QueryExpression(0)) + with self.schema().table("posts") as table: + table.text("content").nullable() + table.integer("votes").default(QueryExpression(0)) user = User.find(1) - post = user.posts().order_by('id', 'asc').first() + 
post = user.posts().order_by("id", "asc").first() - self.assertEqual('User 1 Post 1', post.name) + self.assertEqual("User 1 Post 1", post.name) self.assertEqual(0, post.votes) def test_remove_columns(self): - with self.schema().table('posts') as table: - table.drop_column('name') + with self.schema().table("posts") as table: + table.drop_column("name") - self.assertRaises(ColumnDoesNotExist, self.connection().get_column, 'posts', 'name') + self.assertRaises( + ColumnDoesNotExist, self.connection().get_column, "posts", "name" + ) user = User.find(1) - post = user.posts().order_by('id', 'asc').first() + post = user.posts().order_by("id", "asc").first() - self.assertFalse(hasattr(post, 'name')) + self.assertFalse(hasattr(post, "name")) def test_rename_columns(self): - old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + old_foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("posts") + ) - with self.schema().table('posts') as table: - table.rename_column('name', 'title') + with self.schema().table("posts") as table: + table.rename_column("name", "title") - self.assertRaises(ColumnDoesNotExist, self.connection().get_column, 'posts', 'name') - self.assertIsNotNone(self.connection().get_column('posts', 'title')) + self.assertRaises( + ColumnDoesNotExist, self.connection().get_column, "posts", "name" + ) + self.assertIsNotNone(self.connection().get_column("posts", "title")) - foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("posts") + ) self.assertEqual(len(foreign_keys), len(old_foreign_keys)) user = User.find(1) - post = user.posts().order_by('id', 'asc').first() + post = user.posts().order_by("id", "asc").first() - self.assertEqual('User 1 Post 1', post.title) + self.assertEqual("User 1 Post 1", post.title) def test_rename_columns_with_index(self): - indexes = 
self.connection().get_schema_manager().list_table_indexes('users') - old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + indexes = self.connection().get_schema_manager().list_table_indexes("users") + old_foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("posts") + ) - index = indexes['users_email_unique'] - self.assertEqual(['email'], index.get_columns()) + index = indexes["users_email_unique"] + self.assertEqual(["email"], index.get_columns()) self.assertTrue(index.is_unique()) - with self.schema().table('users') as table: - table.rename_column('email', 'email_address') + with self.schema().table("users") as table: + table.rename_column("email", "email_address") - self.assertRaises(ColumnDoesNotExist, self.connection().get_column, 'users', 'email') - self.assertIsNotNone(self.connection().get_column('users', 'email_address')) - foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + self.assertRaises( + ColumnDoesNotExist, self.connection().get_column, "users", "email" + ) + self.assertIsNotNone(self.connection().get_column("users", "email_address")) + foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("posts") + ) self.assertEqual(len(foreign_keys), len(old_foreign_keys)) - indexes = self.connection().get_schema_manager().list_table_indexes('users') + indexes = self.connection().get_schema_manager().list_table_indexes("users") - index = indexes['users_email_unique'] - self.assertEqual('users_email_unique', index.get_name()) - self.assertEqual(['email_address'], index.get_columns()) + index = indexes["users_email_unique"] + self.assertEqual("users_email_unique", index.get_name()) + self.assertEqual(["email_address"], index.get_columns()) self.assertTrue(index.is_unique()) def test_rename_columns_with_foreign_keys(self): - old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + old_foreign_keys 
= ( + self.connection().get_schema_manager().list_table_foreign_keys("posts") + ) old_foreign_key = old_foreign_keys[0] - self.assertEqual(['user_id'], old_foreign_key.get_local_columns()) - self.assertEqual(['id'], old_foreign_key.get_foreign_columns()) - self.assertEqual('users', old_foreign_key.get_foreign_table_name()) + self.assertEqual(["user_id"], old_foreign_key.get_local_columns()) + self.assertEqual(["id"], old_foreign_key.get_foreign_columns()) + self.assertEqual("users", old_foreign_key.get_foreign_table_name()) - with self.schema().table('posts') as table: - table.rename_column('user_id', 'my_user_id') + with self.schema().table("posts") as table: + table.rename_column("user_id", "my_user_id") - self.assertRaises(ColumnDoesNotExist, self.connection().get_column, 'posts', 'user_id') - self.assertIsNotNone(self.connection().get_column('posts', 'my_user_id')) - foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + self.assertRaises( + ColumnDoesNotExist, self.connection().get_column, "posts", "user_id" + ) + self.assertIsNotNone(self.connection().get_column("posts", "my_user_id")) + foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("posts") + ) self.assertEqual(len(foreign_keys), len(old_foreign_keys)) foreign_key = foreign_keys[0] - self.assertEqual(['my_user_id'], foreign_key.get_local_columns()) - self.assertEqual(['id'], foreign_key.get_foreign_columns()) - self.assertEqual('users', foreign_key.get_foreign_table_name()) + self.assertEqual(["my_user_id"], foreign_key.get_local_columns()) + self.assertEqual(["id"], foreign_key.get_foreign_columns()) + self.assertEqual("users", foreign_key.get_foreign_table_name()) def test_change_columns(self): - with self.schema().table('posts') as table: - table.integer('votes').default(0) + with self.schema().table("posts") as table: + table.integer("votes").default(0) - indexes = self.connection().get_schema_manager().list_table_indexes('posts') - 
old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + indexes = self.connection().get_schema_manager().list_table_indexes("posts") + old_foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("posts") + ) - self.assertIn('posts_name_unique', indexes) - self.assertEqual(['name'], indexes['posts_name_unique'].get_columns()) - self.assertTrue(indexes['posts_name_unique'].is_unique()) + self.assertIn("posts_name_unique", indexes) + self.assertEqual(["name"], indexes["posts_name_unique"].get_columns()) + self.assertTrue(indexes["posts_name_unique"].is_unique()) post = Post.find(1) self.assertEqual(0, post.votes) - with self.schema().table('posts') as table: - table.string('name').nullable().change() - table.string('votes').default('0').change() - table.string('tag').default('new').change() + with self.schema().table("posts") as table: + table.string("name").nullable().change() + table.string("votes").default("0").change() + table.string("tag").default("new").change() - name_column = self.connection().get_column('posts', 'name') - votes_column = self.connection().get_column('posts', 'votes') - status_column = self.connection().get_column('posts', 'status') - tag_column = self.connection().get_column('posts', 'tag') + name_column = self.connection().get_column("posts", "name") + votes_column = self.connection().get_column("posts", "votes") + status_column = self.connection().get_column("posts", "status") + tag_column = self.connection().get_column("posts", "tag") self.assertFalse(name_column.get_notnull()) self.assertTrue(votes_column.get_notnull()) self.assertEqual("0", votes_column.get_default()) @@ -199,44 +237,50 @@ def test_change_columns(self): self.assertFalse(tag_column.get_notnull()) self.assertEqual("new", tag_column.get_default()) - foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts') + foreign_keys = ( + 
self.connection().get_schema_manager().list_table_foreign_keys("posts") + ) self.assertEqual(len(foreign_keys), len(old_foreign_keys)) - indexes = self.connection().get_schema_manager().list_table_indexes('posts') + indexes = self.connection().get_schema_manager().list_table_indexes("posts") - self.assertIn('posts_name_unique', indexes) - self.assertEqual(['name'], indexes['posts_name_unique'].get_columns()) - self.assertTrue(indexes['posts_name_unique'].is_unique()) + self.assertIn("posts_name_unique", indexes) + self.assertEqual(["name"], indexes["posts_name_unique"].get_columns()) + self.assertTrue(indexes["posts_name_unique"].is_unique()) post = Post.find(1) - self.assertEqual('0', post.votes) + self.assertEqual("0", post.votes) - with self.schema().table('users') as table: - table.big_integer('votes').change() + with self.schema().table("users") as table: + table.big_integer("votes").change() def test_cascading(self): - user = User.create(email='john@doe.com') - friend = User.create(email='jane@doe.com') - another_friend = User.create(email='another@doe.com') + user = User.create(email="john@doe.com") + friend = User.create(email="jane@doe.com") + another_friend = User.create(email="another@doe.com") user.friends().attach(friend) user.friends().attach(another_friend) user.delete() - self.assertEqual(0, user.get_connection_resolver().connection().table('friends').count()) + self.assertEqual( + 0, user.get_connection_resolver().connection().table("friends").count() + ) # Altering users table - with self.schema().table('users') as table: - table.string('email', 50).change() + with self.schema().table("users") as table: + table.string("email", 50).change() - user = User.create(email='john@doe.com') + user = User.create(email="john@doe.com") user.friends().attach(friend) user.friends().attach(another_friend) user.delete() - self.assertEqual(0, user.get_connection_resolver().connection().table('friends').count()) + self.assertEqual( + 0, 
user.get_connection_resolver().connection().table("friends").count() + ) def grammar(self): return self.connection().get_default_query_grammar() @@ -252,19 +296,19 @@ class User(Model): __guarded__ = [] - @belongs_to_many('friends', 'user_id', 'friend_id') + @belongs_to_many("friends", "user_id", "friend_id") def friends(self): return User - @has_many('user_id') + @has_many("user_id") def posts(self): return Post - @has_one('user_id') + @has_one("user_id") def post(self): return Post - @morph_many('imageable') + @morph_many("imageable") def photos(self): return Photo @@ -273,11 +317,11 @@ class Post(Model): __guarded__ = [] - @belongs_to('user_id') + @belongs_to("user_id") def user(self): return User - @morph_many('imageable') + @morph_many("imageable") def photos(self): return Photo diff --git a/tests/schema/integrations/test_mysql.py b/tests/schema/integrations/test_mysql.py index 1e1c93e1..d1e1b7f8 100644 --- a/tests/schema/integrations/test_mysql.py +++ b/tests/schema/integrations/test_mysql.py @@ -8,7 +8,6 @@ class SchemaBuilderMySQLIntegrationTestCase(IntegrationTestCase, OratorTestCase): - @classmethod def get_connection_resolver(cls): return DatabaseIntegrationConnectionResolver() @@ -22,28 +21,26 @@ def connection(self, name=None): if self._connection: return self._connection - ci = os.environ.get('CI', False) + ci = os.environ.get("CI", False) if ci: - database = 'orator_test' - user = 'root' - password = '' + database = "orator_test" + user = "root" + password = "" else: - database = 'orator_test' - user = 'orator' - password = 'orator' + database = "orator_test" + user = "orator" + password = "orator" self._connection = MySQLConnection( - MySQLConnector().connect({ - 'database': database, - 'user': user, - 'password': password - }) + MySQLConnector().connect( + {"database": database, "user": user, "password": password} + ) ) return self._connection def get_default_connection(self): - return 'default' + return "default" def set_default_connection(self, 
name): pass diff --git a/tests/schema/integrations/test_postgres.py b/tests/schema/integrations/test_postgres.py index a54688dc..fe141e35 100644 --- a/tests/schema/integrations/test_postgres.py +++ b/tests/schema/integrations/test_postgres.py @@ -8,7 +8,6 @@ class SchemaBuilderPostgresIntegrationTestCase(IntegrationTestCase, OratorTestCase): - @classmethod def get_connection_resolver(cls): return DatabaseIntegrationConnectionResolver() @@ -22,28 +21,26 @@ def connection(self, name=None): if self._connection: return self._connection - ci = os.environ.get('CI', False) + ci = os.environ.get("CI", False) if ci: - database = 'orator_test' - user = 'postgres' + database = "orator_test" + user = "postgres" password = None else: - database = 'orator_test' - user = 'orator' - password = 'orator' + database = "orator_test" + user = "orator" + password = "orator" self._connection = PostgresConnection( - PostgresConnector().connect({ - 'database': database, - 'user': user, - 'password': password - }) + PostgresConnector().connect( + {"database": database, "user": user, "password": password} + ) ) return self._connection def get_default_connection(self): - return 'default' + return "default" def set_default_connection(self, name): pass diff --git a/tests/schema/integrations/test_sqlite.py b/tests/schema/integrations/test_sqlite.py index 17ffab80..5bb2e6c8 100644 --- a/tests/schema/integrations/test_sqlite.py +++ b/tests/schema/integrations/test_sqlite.py @@ -8,7 +8,6 @@ class SchemaBuilderSQLiteIntegrationTestCase(IntegrationTestCase, OratorTestCase): - @classmethod def get_connection_resolver(cls): return DatabaseIntegrationConnectionResolver() @@ -17,78 +16,87 @@ def test_foreign_keys_creation(self): pass def test_rename_columns_with_foreign_keys(self): - super(SchemaBuilderSQLiteIntegrationTestCase, self).test_rename_columns_with_foreign_keys() + super( + SchemaBuilderSQLiteIntegrationTestCase, self + ).test_rename_columns_with_foreign_keys() - old_foreign_keys = 
self.connection().get_schema_manager().list_table_foreign_keys('friends') + old_foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("friends") + ) - with self.schema().table('friends') as table: - table.rename_column('user_id', 'my_user_id') + with self.schema().table("friends") as table: + table.rename_column("user_id", "my_user_id") - foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('friends') + foreign_keys = ( + self.connection().get_schema_manager().list_table_foreign_keys("friends") + ) self.assertEqual(len(old_foreign_keys), len(foreign_keys)) class SchemaBuilderSQLiteIntegrationCascadingTestCase(OratorTestCase): - @classmethod def setUpClass(cls): - Model.set_connection_resolver(DatabaseIntegrationConnectionWithoutForeignKeysResolver()) + Model.set_connection_resolver( + DatabaseIntegrationConnectionWithoutForeignKeysResolver() + ) @classmethod def tearDownClass(cls): Model.unset_connection_resolver() def setUp(self): - with self.schema().create('users') as table: - table.increments('id') - table.string('email').unique() + with self.schema().create("users") as table: + table.increments("id") + table.string("email").unique() table.timestamps() - with self.schema().create('friends') as table: - table.integer('user_id') - table.integer('friend_id') + with self.schema().create("friends") as table: + table.integer("user_id") + table.integer("friend_id") - table.foreign('user_id').references('id').on('users').on_delete('cascade') - table.foreign('friend_id').references('id').on('users').on_delete('cascade') + table.foreign("user_id").references("id").on("users").on_delete("cascade") + table.foreign("friend_id").references("id").on("users").on_delete("cascade") - with self.schema().create('posts') as table: - table.increments('id') - table.integer('user_id') - table.string('name').unique() + with self.schema().create("posts") as table: + table.increments("id") + table.integer("user_id") + 
table.string("name").unique() table.timestamps() - table.foreign('user_id').references('id').on('users') + table.foreign("user_id").references("id").on("users") - with self.schema().create('photos') as table: - table.increments('id') - table.morphs('imageable') - table.string('name') + with self.schema().create("photos") as table: + table.increments("id") + table.morphs("imageable") + table.string("name") table.timestamps() for i in range(10): - user = User.create(email='user%d@foo.com' % (i + 1)) + user = User.create(email="user%d@foo.com" % (i + 1)) for j in range(10): - post = Post(name='User %d Post %d' % (user.id, j + 1)) + post = Post(name="User %d Post %d" % (user.id, j + 1)) user.posts().save(post) def tearDown(self): - self.schema().drop('photos') - self.schema().drop('posts') - self.schema().drop('friends') - self.schema().drop('users') + self.schema().drop("photos") + self.schema().drop("posts") + self.schema().drop("friends") + self.schema().drop("users") def test_cascading(self): - user = User.create(email='john@doe.com') - friend = User.create(email='jane@doe.com') - another_friend = User.create(email='another@doe.com') + user = User.create(email="john@doe.com") + friend = User.create(email="jane@doe.com") + another_friend = User.create(email="another@doe.com") user.friends().attach(friend) user.friends().attach(another_friend) user.delete() - self.assertEqual(2, user.get_connection_resolver().connection().table('friends').count()) + self.assertEqual( + 2, user.get_connection_resolver().connection().table("friends").count() + ) def connection(self): return Model.get_connection_resolver().connection() @@ -105,18 +113,22 @@ def connection(self, name=None): if self._connection: return self._connection - self._connection = SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'})) + self._connection = SQLiteConnection( + SQLiteConnector().connect({"database": ":memory:"}) + ) return self._connection def get_default_connection(self): - return 
'default' + return "default" def set_default_connection(self, name): pass -class DatabaseIntegrationConnectionWithoutForeignKeysResolver(DatabaseIntegrationConnectionResolver): +class DatabaseIntegrationConnectionWithoutForeignKeysResolver( + DatabaseIntegrationConnectionResolver +): _connection = None @@ -125,10 +137,7 @@ def connection(self, name=None): return self._connection self._connection = SQLiteConnection( - SQLiteConnector().connect( - {'database': ':memory:', - 'foreign_keys': False} - ) + SQLiteConnector().connect({"database": ":memory:", "foreign_keys": False}) ) return self._connection diff --git a/tests/schema/test_blueprint.py b/tests/schema/test_blueprint.py index 3eb6c737..94af46c0 100644 --- a/tests/schema/test_blueprint.py +++ b/tests/schema/test_blueprint.py @@ -8,38 +8,39 @@ class SchemaBuilderTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_to_sql_runs_commands_from_blueprint(self): conn = flexmock(Connection(None)) - conn.should_receive('statement').once().with_args('foo') - conn.should_receive('statement').once().with_args('bar') + conn.should_receive("statement").once().with_args("foo") + conn.should_receive("statement").once().with_args("bar") grammar = flexmock(SchemaGrammar(conn)) - blueprint = flexmock(Blueprint('table')) - blueprint.should_receive('to_sql').once().with_args(conn, grammar).and_return(['foo', 'bar']) + blueprint = flexmock(Blueprint("table")) + blueprint.should_receive("to_sql").once().with_args(conn, grammar).and_return( + ["foo", "bar"] + ) blueprint.build(conn, grammar) def test_index_default_names(self): - blueprint = Blueprint('users') - blueprint.unique(['foo', 'bar']) + blueprint = Blueprint("users") + blueprint.unique(["foo", "bar"]) commands = blueprint.get_commands() - self.assertEqual('users_foo_bar_unique', commands[0].index) + self.assertEqual("users_foo_bar_unique", commands[0].index) - blueprint = Blueprint('users') - blueprint.index('foo') + blueprint = Blueprint("users") + 
blueprint.index("foo") commands = blueprint.get_commands() - self.assertEqual('users_foo_index', commands[0].index) + self.assertEqual("users_foo_index", commands[0].index) def test_drop_index_default_names(self): - blueprint = Blueprint('users') - blueprint.drop_unique(['foo', 'bar']) + blueprint = Blueprint("users") + blueprint.drop_unique(["foo", "bar"]) commands = blueprint.get_commands() - self.assertEqual('users_foo_bar_unique', commands[0].index) + self.assertEqual("users_foo_bar_unique", commands[0].index) - blueprint = Blueprint('users') - blueprint.drop_index(['foo']) + blueprint = Blueprint("users") + blueprint.drop_index(["foo"]) commands = blueprint.get_commands() - self.assertEqual('users_foo_index', commands[0].index) + self.assertEqual("users_foo_index", commands[0].index) diff --git a/tests/schema/test_builder.py b/tests/schema/test_builder.py index bf88f2cb..cd3e7c62 100644 --- a/tests/schema/test_builder.py +++ b/tests/schema/test_builder.py @@ -7,17 +7,18 @@ class SchemaBuilderTestCase(OratorTestCase): - def tearDown(self): flexmock_teardown() def test_has_table_correctly_calls_grammar(self): connection = flexmock(Connection(None)) grammar = flexmock() - connection.should_receive('get_schema_grammar').and_return(grammar) + connection.should_receive("get_schema_grammar").and_return(grammar) builder = SchemaBuilder(connection) - grammar.should_receive('compile_table_exists').once().and_return('sql') - connection.should_receive('get_table_prefix').once().and_return('prefix_') - connection.should_receive('select').once().with_args('sql', ['prefix_table']).and_return(['prefix_table']) + grammar.should_receive("compile_table_exists").once().and_return("sql") + connection.should_receive("get_table_prefix").once().and_return("prefix_") + connection.should_receive("select").once().with_args( + "sql", ["prefix_table"] + ).and_return(["prefix_table"]) - self.assertTrue(builder.has_table('table')) + self.assertTrue(builder.has_table("table")) diff --git 
a/tests/seeds/__init__.py b/tests/seeds/__init__.py index 633f8661..40a96afc 100644 --- a/tests/seeds/__init__.py +++ b/tests/seeds/__init__.py @@ -1,2 +1 @@ # -*- coding: utf-8 -*- - diff --git a/tests/seeds/test_seeder.py b/tests/seeds/test_seeder.py index 1a9a6bbd..fdc7ba67 100644 --- a/tests/seeds/test_seeder.py +++ b/tests/seeds/test_seeder.py @@ -10,33 +10,32 @@ class SeederTestCase(OratorTestCase): - def tearDown(self): super(SeederTestCase, self).tearDown() flexmock_teardown() def test_call_resolve_class_and_calls_run(self): resolver_mock = flexmock(DatabaseManager) - resolver_mock.should_receive('connection').and_return({}) + resolver_mock.should_receive("connection").and_return({}) resolver = flexmock(DatabaseManager({})) connection = flexmock(Connection(None)) - resolver.should_receive('connection').with_args(None).and_return(connection) + resolver.should_receive("connection").with_args(None).and_return(connection) seeder = Seeder(resolver) - command = flexmock(Command('foo')) - command.should_receive('line').once() + command = flexmock(Command("foo")) + command.should_receive("line").once() seeder.set_command(command) child = flexmock() - child.__name__ = 'foo' - child.should_receive('set_command').once().with_args(command) - child.should_receive('set_connection_resolver').once().with_args(resolver) - child.should_receive('run').once() + child.__name__ = "foo" + child.should_receive("set_command").once().with_args(command) + child.should_receive("set_connection_resolver").once().with_args(resolver) + child.should_receive("run").once() seeder.call(child) class Command(BaseCommand): - resolver = 'bar' + resolver = "bar" def get_output(self): - return 'foo' + return "foo" diff --git a/tests/support/test_collection.py b/tests/support/test_collection.py index 171c2ac7..376cda08 100644 --- a/tests/support/test_collection.py +++ b/tests/support/test_collection.py @@ -5,33 +5,32 @@ class CollectionTestCase(OratorTestCase): - def 
test_first_returns_first_item_in_collection(self): - c = Collection(['foo', 'bar']) + c = Collection(["foo", "bar"]) - self.assertEqual('foo', c.first()) + self.assertEqual("foo", c.first()) def test_last_returns_last_item_in_collection(self): - c = Collection(['foo', 'bar']) + c = Collection(["foo", "bar"]) - self.assertEqual('bar', c.last()) + self.assertEqual("bar", c.last()) def test_pop_removes_and_returns_last_item_or_specified_index(self): - c = Collection(['foo', 'bar']) + c = Collection(["foo", "bar"]) - self.assertEqual('bar', c.pop()) - self.assertEqual('foo', c.last()) + self.assertEqual("bar", c.pop()) + self.assertEqual("foo", c.last()) - c = Collection(['foo', 'bar']) + c = Collection(["foo", "bar"]) - self.assertEqual('foo', c.pop(0)) - self.assertEqual('bar', c.first()) + self.assertEqual("foo", c.pop(0)) + self.assertEqual("bar", c.first()) def test_shift_removes_and_returns_first_item(self): - c = Collection(['foo', 'bar']) + c = Collection(["foo", "bar"]) - self.assertEqual('foo', c.shift()) - self.assertEqual('bar', c.first()) + self.assertEqual("foo", c.shift()) + self.assertEqual("bar", c.first()) def test_empty_collection_is_empty(self): c = Collection() @@ -41,8 +40,8 @@ def test_empty_collection_is_empty(self): self.assertTrue(c2.is_empty()) def test_collection_is_constructed(self): - c = Collection('foo') - self.assertEqual(['foo'], c.all()) + c = Collection("foo") + self.assertEqual(["foo"], c.all()) c = Collection(2) self.assertEqual([2], c.all()) @@ -57,25 +56,25 @@ def test_collection_is_constructed(self): self.assertEqual([], c.all()) def test_offset_access(self): - c = Collection(['foo', 'bar']) - self.assertEqual('bar', c[1]) + c = Collection(["foo", "bar"]) + self.assertEqual("bar", c[1]) - c[1] = 'baz' - self.assertEqual('baz', c[1]) + c[1] = "baz" + self.assertEqual("baz", c[1]) del c[0] - self.assertEqual('baz', c[0]) + self.assertEqual("baz", c[0]) def test_forget(self): - c = Collection(['foo', 'bar', 'boom']) + c = 
Collection(["foo", "bar", "boom"]) c.forget(0) - self.assertEqual('bar', c[0]) + self.assertEqual("bar", c[0]) c.forget(0, 1) self.assertTrue(c.is_empty()) def test_get_avg_items_from_collection(self): - c = Collection([{'foo': 10}, {'foo': 20}]) - self.assertEqual(15, c.avg('foo')) + c = Collection([{"foo": 10}, {"foo": 20}]) + self.assertEqual(15, c.avg("foo")) c = Collection([1, 2, 3, 4, 5]) self.assertEqual(3, c.avg()) @@ -103,31 +102,31 @@ def test_contains(self): self.assertFalse(c.contains(lambda x: x > 5)) self.assertIn(3, c) - c = Collection([{'v': 1}, {'v': 3}, {'v': 5}]) - self.assertTrue(c.contains('v', 1)) - self.assertFalse(c.contains('v', 2)) + c = Collection([{"v": 1}, {"v": 3}, {"v": 5}]) + self.assertTrue(c.contains("v", 1)) + self.assertFalse(c.contains("v", 2)) - obj1 = type('lamdbaobject', (object,), {})() + obj1 = type("lamdbaobject", (object,), {})() obj1.v = 1 - obj2 = type('lamdbaobject', (object,), {})() + obj2 = type("lamdbaobject", (object,), {})() obj2.v = 3 - obj3 = type('lamdbaobject', (object,), {})() + obj3 = type("lamdbaobject", (object,), {})() obj3.v = 5 - c = Collection([{'v': 1}, {'v': 3}, {'v': 5}]) - self.assertTrue(c.contains('v', 1)) - self.assertFalse(c.contains('v', 2)) + c = Collection([{"v": 1}, {"v": 3}, {"v": 5}]) + self.assertTrue(c.contains("v", 1)) + self.assertFalse(c.contains("v", 2)) def test_countable(self): - c = Collection(['foo', 'bar']) + c = Collection(["foo", "bar"]) self.assertEqual(2, c.count()) self.assertEqual(2, len(c)) def test_diff(self): - c = Collection(['foo', 'bar']) - self.assertEqual(['foo'], c.diff(Collection(['bar', 'baz'])).all()) + c = Collection(["foo", "bar"]) + self.assertEqual(["foo"], c.diff(Collection(["bar", "baz"])).all()) def test_each(self): - original = ['foo', 'bar', 'baz'] + original = ["foo", "bar", "baz"] c = Collection(original) result = [] @@ -141,35 +140,39 @@ def test_every(self): self.assertEqual([2, 4, 6], c.every(2, 1).all()) def test_filter(self): - c = 
Collection([{'id': 1, 'name': 'hello'}, {'id': 2, 'name': 'world'}]) - self.assertEqual([{'id': 2, 'name': 'world'}], c.filter(lambda item: item['id'] == 2).all()) + c = Collection([{"id": 1, "name": "hello"}, {"id": 2, "name": "world"}]) + self.assertEqual( + [{"id": 2, "name": "world"}], c.filter(lambda item: item["id"] == 2).all() + ) - c = Collection(['', 'hello', '', 'world']) - self.assertEqual(['hello', 'world'], c.filter().all()) + c = Collection(["", "hello", "", "world"]) + self.assertEqual(["hello", "world"], c.filter().all()) def test_where(self): - c = Collection([{'v': 1}, {'v': 3}, {'v': 2}, {'v': 3}, {'v': 4}]) - self.assertEqual([{'v': 3}, {'v': 3}], c.where('v', 3).all()) + c = Collection([{"v": 1}, {"v": 3}, {"v": 2}, {"v": 3}, {"v": 4}]) + self.assertEqual([{"v": 3}, {"v": 3}], c.where("v", 3).all()) def test_implode(self): - obj1 = type('lamdbaobject', (object,), {})() - obj1.name = 'john' - obj1.email = 'foo' - c = Collection([{'name': 'john', 'email': 'foo'}, {'name': 'jane', 'email': 'bar'}]) - self.assertEqual('foobar', c.implode('email')) - self.assertEqual('foo,bar', c.implode('email', ',')) + obj1 = type("lamdbaobject", (object,), {})() + obj1.name = "john" + obj1.email = "foo" + c = Collection( + [{"name": "john", "email": "foo"}, {"name": "jane", "email": "bar"}] + ) + self.assertEqual("foobar", c.implode("email")) + self.assertEqual("foo,bar", c.implode("email", ",")) - c = Collection(['foo', 'bar']) - self.assertEqual('foobar', c.implode('')) - self.assertEqual('foo,bar', c.implode(',')) + c = Collection(["foo", "bar"]) + self.assertEqual("foobar", c.implode("")) + self.assertEqual("foo,bar", c.implode(",")) def test_lists(self): - obj1 = type('lamdbaobject', (object,), {})() - obj1.name = 'john' - obj1.email = 'foo' - c = Collection([obj1, {'name': 'jane', 'email': 'bar'}]) - self.assertEqual({'john': 'foo', 'jane': 'bar'}, c.lists('email', 'name')) - self.assertEqual(['foo', 'bar'], c.pluck('email').all()) + obj1 = 
type("lamdbaobject", (object,), {})() + obj1.name = "john" + obj1.email = "foo" + c = Collection([obj1, {"name": "jane", "email": "bar"}]) + self.assertEqual({"john": "foo", "jane": "bar"}, c.lists("email", "name")) + self.assertEqual(["foo", "bar"], c.pluck("email").all()) def test_map(self): c = Collection([1, 2, 3, 4, 5]) @@ -247,12 +250,9 @@ def test_without(self): self.assertEqual([1, 2, 3, 4, 5], c.all()) def test_flatten(self): - c = Collection({'foo': [5, 6], 'bar': 7, 'baz': {'boom': [1, 2, 3, 4]}}) + c = Collection({"foo": [5, 6], "bar": 7, "baz": {"boom": [1, 2, 3, 4]}}) - self.assertEqual( - [1, 2, 3, 4, 5, 6, 7], - c.flatten().sort().all() - ) + self.assertEqual([1, 2, 3, 4, 5, 6, 7], c.flatten().sort().all()) c = Collection([1, [2, 3], 4]) self.assertEqual([1, 2, 3, 4], c.flatten().all()) diff --git a/tests/support/test_fluent.py b/tests/support/test_fluent.py index cf111cc6..4ae07f4d 100644 --- a/tests/support/test_fluent.py +++ b/tests/support/test_fluent.py @@ -5,33 +5,34 @@ class FluentTestCase(OratorTestCase): - def test_get_method_return_attributes(self): - fluent = Fluent(name='john') + fluent = Fluent(name="john") - self.assertEqual('john', fluent.get('name')) - self.assertEqual('default', fluent.get('foo', 'default')) - self.assertEqual('john', fluent.name) + self.assertEqual("john", fluent.get("name")) + self.assertEqual("default", fluent.get("foo", "default")) + self.assertEqual("john", fluent.name) self.assertEqual(None, fluent.foo) def test_set_attributes(self): fluent = Fluent() - fluent.name = 'john' + fluent.name = "john" fluent.developer() fluent.age(25) - self.assertEqual('john', fluent.name) + self.assertEqual("john", fluent.name) self.assertTrue(fluent.developer) self.assertEqual(25, fluent.age) - self.assertEqual({'name': 'john', 'developer': True, 'age': 25}, fluent.get_attributes()) + self.assertEqual( + {"name": "john", "developer": True, "age": 25}, fluent.get_attributes() + ) def test_chained_attributes(self): fluent = 
Fluent() fluent.unsigned = False - fluent.integer('status').unsigned() + fluent.integer("status").unsigned() - self.assertEqual('status', fluent.integer) + self.assertEqual("status", fluent.integer) self.assertTrue(fluent.unsigned) diff --git a/tests/test_database_manager.py b/tests/test_database_manager.py index b4bffc6a..1d783033 100644 --- a/tests/test_database_manager.py +++ b/tests/test_database_manager.py @@ -8,17 +8,14 @@ class ConnectionTestCase(OratorTestCase): - def test_connection_method_create_a_new_connection_if_needed(self): manager = self._get_manager() - manager.table('users') + manager.table("users") - manager._make_connection.assert_called_once_with( - 'sqlite' - ) + manager._make_connection.assert_called_once_with("sqlite") manager._make_connection.reset_mock() - manager.table('users') + manager.table("users") self.assertFalse(manager._make_connection.called) def test_manager_uses_factory_to_create_connections(self): @@ -28,37 +25,30 @@ def test_manager_uses_factory_to_create_connections(self): manager.connection() manager._factory.make.assert_called_with( - { - 'name': 'sqlite', - 'driver': 'sqlite', - 'database': ':memory:' - }, 'sqlite' + {"name": "sqlite", "driver": "sqlite", "database": ":memory:"}, "sqlite" ) manager._factory.make = original_make def test_connection_can_select_connections(self): manager = self._get_manager() - self.assertEqual(manager.connection(), manager.connection('sqlite')) - self.assertNotEqual(manager.connection('sqlite'), manager.connection('sqlite2')) + self.assertEqual(manager.connection(), manager.connection("sqlite")) + self.assertNotEqual(manager.connection("sqlite"), manager.connection("sqlite2")) def test_dynamic_attribute_gets_connection_attribute(self): manager = self._get_manager() - manager.statement('CREATE TABLE users') + manager.statement("CREATE TABLE users") - manager.get_connections()['sqlite'].statement.assert_called_once_with( - 'CREATE TABLE users' + 
manager.get_connections()["sqlite"].statement.assert_called_once_with( + "CREATE TABLE users" ) def test_default_database_with_one_database(self): - manager = MockManager({ - 'sqlite': { - 'driver': 'sqlite', - 'database': ':memory:' - } - }).prepare_mock() + manager = MockManager( + {"sqlite": {"driver": "sqlite", "database": ":memory:"}} + ).prepare_mock() - self.assertEqual('sqlite', manager.get_default_connection()) + self.assertEqual("sqlite", manager.get_default_connection()) def test_reconnect(self): manager = self._get_real_manager() @@ -73,31 +63,26 @@ def test_set_default_connection_with_none_should_not_overwrite(self): manager.set_default_connection(None) - self.assertEqual('sqlite', manager.get_default_connection()) + self.assertEqual("sqlite", manager.get_default_connection()) def _get_manager(self): - manager = MockManager({ - 'default': 'sqlite', - 'sqlite': { - 'driver': 'sqlite', - 'database': ':memory:' - }, - 'sqlite2': { - 'driver': 'sqlite', - 'database': ':memory:' + manager = MockManager( + { + "default": "sqlite", + "sqlite": {"driver": "sqlite", "database": ":memory:"}, + "sqlite2": {"driver": "sqlite", "database": ":memory:"}, } - }).prepare_mock() + ).prepare_mock() return manager def _get_real_manager(self): - manager = DatabaseManager({ - 'default': 'sqlite', - 'sqlite': { - 'driver': 'sqlite', - 'database': ':memory:' + manager = DatabaseManager( + { + "default": "sqlite", + "sqlite": {"driver": "sqlite", "database": ":memory:"}, } - }) + ) return manager diff --git a/tests/utils.py b/tests/utils.py index 8f3e09c8..91b0c400 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,7 +10,6 @@ class MockConnection(ConnectionInterface): - def __init__(self, name=None): if name: self.get_name = lambda: name @@ -25,12 +24,12 @@ def prepare_mock(self, name=None): self.update = mock.MagicMock() self.delete = mock.MagicMock() self.statement = mock.MagicMock() + self.select_many = mock.MagicMock() return self class MockProcessor(QueryProcessor): 
- def prepare_mock(self): self.process_select = mock.MagicMock() self.process_insert_get_id = mock.MagicMock() @@ -39,7 +38,6 @@ def prepare_mock(self): class MockManager(DatabaseManager): - def prepare_mock(self): self._make_connection = mock.MagicMock( side_effect=lambda name: MockConnection(name).prepare_mock() @@ -49,7 +47,6 @@ def prepare_mock(self): class MockFactory(ConnectionFactory): - def prepare_mock(self): self.make = mock.MagicMock(return_value=MockConnection().prepare_mock()) @@ -57,18 +54,16 @@ def prepare_mock(self): class MockQueryBuilder(QueryBuilder): - def prepare_mock(self): - self.from__ = 'foo_table' + self.from__ = "foo_table" return self class MockModel(Model): - def prepare_mock(self): - self.get_key_name = mock.MagicMock(return_value='foo') - self.get_table = mock.MagicMock(return_value='foo_table') - self.get_qualified_key_name = mock.MagicMock(return_value='foo_table.foo') + self.get_key_name = mock.MagicMock(return_value="foo") + self.get_table = mock.MagicMock(return_value="foo_table") + self.get_qualified_key_name = mock.MagicMock(return_value="foo_table.foo") return self diff --git a/tox.ini b/tox.ini index 16805b78..61907978 100644 --- a/tox.ini +++ b/tox.ini @@ -1,15 +1,42 @@ [tox] -envlist = py{27,35}-{pymysql,mysqlclient} +isolated_build = true +envlist = py{27,35,36,37}-{pymysql,mysqlclient} [testenv] -deps = - -rtests-requirements.txt - pymysql: pymysql - mysqlclient: mysqlclient -commands = py.test tests/ -sq - -[testenv:flake8] -basepython=python -deps=flake8 -commands= - flake8 orator +whitelist_externals = poetry + +[testenv:py27-pymysql] +commands: + poetry run pip install -U pip + poetry install -E mysql-python -E pgsql -v + poetry run pytest tests/ -sq + +[testenv:py27-mysqlclient] +commands: + poetry run pip install -U pip + poetry install -E mysql -E pgsql -v + poetry run pytest tests/ -sq + +[testenv:py35-pymysql] +commands: + poetry run pip install -U pip + poetry install -E mysql-python -E pgsql -v + poetry run 
pytest tests/ -sq + +[testenv:py35-mysqlclient] +commands: + poetry run pip install -U pip + poetry install -E mysql -E pgsql -v + poetry run pytest tests/ -sq + +[testenv:py36-pymysql] +commands: + poetry run pip install -U pip + poetry install -E mysql-python -E pgsql -v + poetry run pytest tests/ -sq + +[testenv:py36-mysqlclient] +commands: + poetry run pip install -U pip + poetry install -E mysql -E pgsql -v + poetry run pytest tests/ -sq