From: Artem Silenkov Date: Tue, 19 Jan 2016 16:32:40 +0000 (+0300) Subject: python-sqlalchemy-utils added for Mitaka 9.0 X-Git-Tag: mos-9.0^0 X-Git-Url: https://review.fuel-infra.org/gitweb?a=commitdiff_plain;h=a4b362ee9f1d9bec14f8760dae2e77bb3ef53256;p=packages%2Ftrusty%2Fpython-sqlalchemy-utils.git python-sqlalchemy-utils added for Mitaka 9.0 Change-Id: I3931ab16d331e740107f5bdb18bf4e687fbeaf05 --- diff --git a/debian/changelog b/debian/changelog new file mode 100644 index 0000000..30a647f --- /dev/null +++ b/debian/changelog @@ -0,0 +1,19 @@ +python-sqlalchemy-utils (0.30.12-2~u14.04+mos1) mos9.0; urgency=medium + + * Sources are from: + http://ftp.acc.umu.se/debian/pool/main/p/python-sqlalchemy-utils/ + * Requirements are updated for Mitaka 9.0 + + -- Artem Silenkov Tue, 19 Jan 2016 19:24:48 +0300 + +python-sqlalchemy-utils (0.30.12-2) unstable; urgency=medium + + * Added reproducibility patch from Chris Lamb (Closes: #799206). + + -- Thomas Goirand Fri, 13 Nov 2015 09:54:23 +0000 + +python-sqlalchemy-utils (0.30.12-1) unstable; urgency=medium + + * Initial release. (Closes: #798561) + + -- Thomas Goirand Fri, 27 Mar 2015 16:18:57 +0100 diff --git a/debian/compat b/debian/compat new file mode 100644 index 0000000..ec63514 --- /dev/null +++ b/debian/compat @@ -0,0 +1 @@ +9 diff --git a/debian/control b/debian/control new file mode 100644 index 0000000..8927e64 --- /dev/null +++ b/debian/control @@ -0,0 +1,53 @@ +Source: python-sqlalchemy-utils +Section: python +Priority: optional +Maintainer: PKG OpenStack +Uploaders: Thomas Goirand +Build-Depends: debhelper (>= 9), + dh-python, + python-all, + python-setuptools, + python-sphinx, + python3-all, + python3-setuptools, +Build-Depends-Indep: python-six, + python-sqlalchemy, + python3-six, + python3-sqlalchemy, +Standards-Version: 3.9.6 +Vcs-Browser: http://anonscm.debian.org/gitweb/?p=openstack/python-sqlalchemy-utils.git +Vcs-Git: git://anonscm.debian.org/openstack/python-sqlalchemy-utils.git +Homepage: https://github.com/kvesteri/sqlalchemy-utils + +Package: python-sqlalchemy-utils +Architecture: all +Depends: ${misc:Depends}, ${python:Depends} +Suggests: python-sqlalchemy-utils-doc +Description: various utility functions for SQLAlchemy - Python 2.x + Various utility functions and custom data types for SQLAlchemy. + . + SQLAlchemy is an SQL database abstraction library for Python. + . + This package contains the Python 2.x module. + +Package: python3-sqlalchemy-utils +Architecture: all +Depends: ${misc:Depends}, ${python3:Depends} +Suggests: python-sqlalchemy-utils-doc +Description: various utility functions for SQLAlchemy - Python 3.x + Various utility functions and custom data types for SQLAlchemy. + . + SQLAlchemy is an SQL database abstraction library for Python. + . + This package contains the Python 3.x module. + +Package: python-sqlalchemy-utils-doc +Section: doc +Architecture: all +Depends: ${misc:Depends}, ${sphinxdoc:Depends} +Description: various utility functions for SQLAlchemy - doc + Various utility functions and custom data types for SQLAlchemy. + . + SQLAlchemy is an SQL database abstraction library for Python. + . + This package contains the documentation. diff --git a/debian/copyright b/debian/copyright new file mode 100644 index 0000000..40a165a --- /dev/null +++ b/debian/copyright @@ -0,0 +1,37 @@ +Format: http://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ +Upstream-Name: nova +Source: git://github.com/kvesteri/sqlalchemy-utils.git + +Files: debian/* +Copyright: (c) 2015, Thomas Goirand +License: BSD-3-clause + +Files: * +Copyright: (c) 2012, Konsta Vesterinen +License: BSD-3-clause + +License: BSD-3-clause + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + . + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + . + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + . + * The names of the contributors may not be used to endorse or promote + products derived from this software without specific prior written + permission. + . + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY DIRECT, + INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, + EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/debian/gbp.conf b/debian/gbp.conf new file mode 100644 index 0000000..10f9500 --- /dev/null +++ b/debian/gbp.conf @@ -0,0 +1,9 @@ +[DEFAULT] +upstream-branch = master +debian-branch = debian/unstable +upstream-tag = %(version)s +compression = xz + +[buildpackage] +export-dir = ../build-area/ + diff --git a/debian/patches/reproducible_build.patch b/debian/patches/reproducible_build.patch new file mode 100644 index 0000000..fb748a8 --- /dev/null +++ b/debian/patches/reproducible_build.patch @@ -0,0 +1,11 @@ +--- python-sqlalchemy-utils-0.30.12.orig/setup.py ++++ python-sqlalchemy-utils-0.30.12/setup.py +@@ -57,7 +57,7 @@ extras_require = { + + # Add all optional dependencies to testing requirements. + test_all = [] +-for name, requirements in extras_require.items(): ++for name, requirements in sorted(extras_require.items()): + test_all += requirements + extras_require['test_all'] = test_all + diff --git a/debian/patches/series b/debian/patches/series new file mode 100644 index 0000000..b2026fe --- /dev/null +++ b/debian/patches/series @@ -0,0 +1 @@ +reproducible_build.patch diff --git a/debian/python-sqlalchemy-utils-doc.doc-base b/debian/python-sqlalchemy-utils-doc.doc-base new file mode 100644 index 0000000..3857e0a --- /dev/null +++ b/debian/python-sqlalchemy-utils-doc.doc-base @@ -0,0 +1,9 @@ +Document: python-sqlalchemy-utils-doc +Title: sqlalchemy-utils Documentation +Author: N/A +Abstract: Sphinx documentation for sqlalchemy-utils +Section: Programming/Python + +Format: HTML +Index: /usr/share/doc/python-sqlalchemy-utils-doc/html/index.html +Files: /usr/share/doc/python-sqlalchemy-utils-doc/html/* diff --git a/debian/rules b/debian/rules new file mode 100755 index 0000000..2660d23 --- /dev/null +++ b/debian/rules @@ -0,0 +1,49 @@ +#!/usr/bin/make -f + +PYTHONS:=$(shell pyversions -vr) +PYTHON3S:=$(shell py3versions -vr) + +UPSTREAM_GIT = git://github.com/kvesteri/sqlalchemy-utils.git +-include /usr/share/openstack-pkg-tools/pkgos.make + +%: + dh $@ --buildsystem=python_distutils --with python2,python3,sphinxdoc + +override_dh_install: + set -e ; for pyvers in $(PYTHONS); do \ + python$$pyvers setup.py install --install-layout=deb \ + --root $(CURDIR)/debian/python-sqlalchemy-utils; \ + done + set -e ; for pyvers in $(PYTHON3S); do \ + python$$pyvers setup.py install --install-layout=deb \ + --root $(CURDIR)/debian/python3-sqlalchemy-utils; \ + done + rm -rf $(CURDIR)/debian/python*-sqlalchemy-utils/usr/lib/python*/dist-packages/*.pth + +override_dh_auto_test: +ifeq (,$(findstring nocheck, $(DEB_BUILD_OPTIONS))) + set -e ; for pyvers in $(PYTHONS) ; do \ + PYTHONPATH=.:./src python$$pyvers setup.py test ; \ + done +endif + +override_dh_sphinxdoc: + PYTHONPATH=. sphinx-build -b html docs debian/python-sqlalchemy-utils-doc/usr/share/doc/python-sqlalchemy-utils-doc/html + dh_sphinxdoc -O--buildsystem=python_distutils + + +override_dh_clean: + dh_clean -O--buildsystem=python_distutils + rm -rf build + + +# Commands not to run +override_dh_installcatalogs: +override_dh_installemacsen override_dh_installifupdown: +override_dh_installinfo override_dh_installmenu override_dh_installmime: +override_dh_installmodules override_dh_installlogcheck: +override_dh_installpam override_dh_installppp override_dh_installudev override_dh_installwm: +override_dh_installxfonts override_dh_gconf override_dh_icons override_dh_perl override_dh_usrlocal: +override_dh_installcron override_dh_installdebconf: +override_dh_installlogrotate override_dh_installgsettings: + \ No newline at end of file diff --git a/debian/source/format b/debian/source/format new file mode 100644 index 0000000..163aaf8 --- /dev/null +++ b/debian/source/format @@ -0,0 +1 @@ +3.0 (quilt) diff --git a/debian/watch b/debian/watch new file mode 100644 index 0000000..a0a460f --- /dev/null +++ b/debian/watch @@ -0,0 +1,3 @@ +version=3 +opts="uversionmangle=s/\.(b|rc)/~$1/" \ +https://github.com/kvesteri/sqlalchemy-utils/tags .*/(\d[\d\.]+)\.tar\.gz diff --git a/python-sqlalchemy-utils/.gitignore b/python-sqlalchemy-utils/.gitignore new file mode 100644 index 0000000..95825ec --- /dev/null +++ b/python-sqlalchemy-utils/.gitignore @@ -0,0 +1,44 @@ +*.py[cod] + +# C extensions +*.so + +# Packages +*.egg +*.egg-info +dist +build +eggs +parts +bin +var +sdist +develop-eggs +.installed.cfg +lib +lib64 +docs/_build + +# Installer logs +pip-log.txt + +# Unit test / coverage reports +.coverage +.tox +nosetests.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# vim +[._]*.s[a-w][a-z] +[._]s[a-w][a-z] +*.un~ +Session.vim +.netrwhist +*~ diff --git a/python-sqlalchemy-utils/.isort.cfg b/python-sqlalchemy-utils/.isort.cfg new file mode 100644 index 0000000..52591e4 --- /dev/null +++ b/python-sqlalchemy-utils/.isort.cfg @@ -0,0 +1,6 @@ +[settings] +known_first_party=sqlalchemy_utils,tests +line_length=79 +multi_line_output=3 +not_skip=__init__.py +order_by_type=false diff --git a/python-sqlalchemy-utils/.travis.yml b/python-sqlalchemy-utils/.travis.yml new file mode 100644 index 0000000..a307723 --- /dev/null +++ b/python-sqlalchemy-utils/.travis.yml @@ -0,0 +1,38 @@ +before_script: + - psql -c 'create database sqlalchemy_utils_test;' -U postgres + - psql -c 'create extension hstore;' -U postgres -d sqlalchemy_utils_test + - mysql -e 'create database sqlalchemy_utils_test;' + +before_install: + - sudo /etc/init.d/postgresql stop + - sudo apt-get update + - sudo apt-get purge postgresql-9.1 postgresql-9.2 postgresql-9.3 + - sudo apt-get install postgresql-client-common postgresql-9.4 postgresql-contrib-9.4 + - sudo apt-get install pidentd + - sudo chmod 777 /etc/postgresql/9.4/main/pg_hba.conf + - sudo echo "local all postgres trust" > /etc/postgresql/9.4/main/pg_hba.conf + - sudo echo "local all all trust" >> /etc/postgresql/9.4/main/pg_hba.conf + - sudo echo "host all all 127.0.0.1/32 trust" >> /etc/postgresql/9.4/main/pg_hba.conf + - sudo echo "host all all ::1/128 trust" >> /etc/postgresql/9.4/main/pg_hba.conf + - sudo sh -c "echo 127.0.0.1 postgres >> /etc/hosts" + - sudo /etc/init.d/postgresql restart + + +language: python +python: + - 2.6 + - 2.7 + - 3.3 + - 3.4 + +env: + - EXTRAS=test + - EXTRAS=test_all + +install: + - pip install -e .[$EXTRAS] + +script: + - isort --recursive --diff sqlalchemy_utils tests && isort --recursive --check-only sqlalchemy_utils tests + - flake8 sqlalchemy_utils tests + - py.test diff --git a/python-sqlalchemy-utils/CHANGES.rst b/python-sqlalchemy-utils/CHANGES.rst new file mode 100644 index 0000000..923a080 --- /dev/null +++ b/python-sqlalchemy-utils/CHANGES.rst @@ -0,0 +1,1000 @@ +Changelog +--------- + +Here you can see the full list of changes between each SQLAlchemy-Utils release. + + +0.30.12 (2015-07-05) +^^^^^^^^^^^^^^^^^^^^ + +- Added support for PhoneNumber extensions (#121) + + +0.30.11 (2015-06-18) +^^^^^^^^^^^^^^^^^^^^ + +- Fix None type handling of ChoiceType +- Make locale casting for translation hybrid expressions cast locales on compilation phase. This extra lazy locale casting is needed in some cases where translation hybrid expressions are used before get_locale +function is available. + + +0.30.10 (2015-06-17) +^^^^^^^^^^^^^^^^^^^^ + +- Added better support for dynamic locales in translation_hybrid +- Make babel dependent primitive types to use Locale('en') for data validation instead of current locale. Using current locale leads to infinite recursion in cases where the loaded data has dependency to the loaded object's locale. + + +0.30.9 (2015-06-09) +^^^^^^^^^^^^^^^^^^^ + +- Added get_type utility function +- Added default parameter for array_agg function + + +0.30.8 (2015-06-05) +^^^^^^^^^^^^^^^^^^^ + +- Added Asterisk compiler +- Added row_to_json GenericFunction +- Added array_agg GenericFunction +- Made quote function accept dialect object as the first paremeter +- Made has_index work with tables without primary keys (#148) + + +0.30.7 (2015-05-28) +^^^^^^^^^^^^^^^^^^^ + +- Fixed CompositeType null handling + + +0.30.6 (2015-05-28) +^^^^^^^^^^^^^^^^^^^ + +- Made psycopg2 requirement optional (#145, #146) +- Made CompositeArray work with tuples given as bind parameters + + +0.30.5 (2015-05-27) +^^^^^^^^^^^^^^^^^^^ + +- Fixed CompositeType bind parameter processing when one of the fields is of TypeDecorator type and +CompositeType is used inside ARRAY type. + + +0.30.4 (2015-05-27) +^^^^^^^^^^^^^^^^^^^ + +- Fixed CompositeType bind parameter processing when one of the fields is of TypeDecorator type. + + +0.30.3 (2015-05-27) +^^^^^^^^^^^^^^^^^^^ + +- Added length property to range types +- Added CompositeType for PostgreSQL + + +0.30.2 (2015-05-21) +^^^^^^^^^^^^^^^^^^^ + +- Fixed ``assert_max_length``, ``assert_non_nullable``, ``assert_min_value`` and ``assert_max_value`` not properly raising an ``AssertionError`` when the assertion failed. + + +0.30.1 (2015-05-06) +^^^^^^^^^^^^^^^^^^^ + +- Drop undocumented batch fetch feature. Let's wait until the inner workings of SQLAlchemy loading API is well-documented. +- Fixed GenericRelationshipProperty comparator to work with SA 1.0.x (#139) +- Make all foreign key helpers SA 1.0 compliant +- Make translation_hybrid expression work the same way as SQLAlchemy-i18n translation expressions +- Update SQLAlchemy dependency to 1.0 + + +0.30.0 (2015-04-15) +^^^^^^^^^^^^^^^^^^^ + +- Added __hash__ method to Country class +- Made Country validate itself during object initialization +- Made Country string coercible +- Removed deprecated function generates +- Fixed observes function to work with simple column properties + + +0.29.9 (2015-04-07) +^^^^^^^^^^^^^^^^^^^ + +- Added CurrencyType (#19) and Currency class + + +0.29.8 (2015-03-03) +^^^^^^^^^^^^^^^^^^^ + +- Added get_class_by_table ORM utility function + + +0.29.7 (2015-03-01) +^^^^^^^^^^^^^^^^^^^ + +- Added Enum representation support for ChoiceType + + +0.29.6 (2015-02-03) +^^^^^^^^^^^^^^^^^^^ + +- Added customizable TranslationHybrid default value + + +0.29.5 (2015-02-03) +^^^^^^^^^^^^^^^^^^^ + +- Made assert_max_length support PostgreSQL array type + + +0.29.4 (2015-01-31) +^^^^^^^^^^^^^^^^^^^ + +- Made CaseInsensitiveComparator not cast already lowercased types to lowercase + + +0.29.3 (2015-01-24) +^^^^^^^^^^^^^^^^^^^ + +- Fixed analyze function runtime property handling for PostgreSQL >= 9.4 +- Fixed drop_database and create_database identifier quoting (#122) + + +0.29.2 (2015-01-08) +^^^^^^^^^^^^^^^^^^^ + +- Removed deprecated defer_except (SQLAlchemy's own load_only should be used from now on) +- Added json_sql PostgreSQL helper function + + +0.29.1 (2015-01-03) +^^^^^^^^^^^^^^^^^^^ + +- Added assert_min_value and assert_max_value testing functions + + +0.29.0 (2015-01-02) +^^^^^^^^^^^^^^^^^^^ + +- Removed TSVectorType.match_tsquery (now replaced by TSVectorType.match to be compatible with SQLAlchemy) +- Removed undocumented function tsvector_concat +- Added support for TSVectorType concatenation through OR operator +- Added documentation for TSVectorType (#102) + + +0.28.3 (2014-12-17) +^^^^^^^^^^^^^^^^^^^ + +- Made aggregated fully support column aliases +- Changed test matrix to run all tests without any optional dependencies (as well as with all optional dependencies) + + +0.28.2 (2014-12-13) +^^^^^^^^^^^^^^^^^^^ + +- Fixed issue with Color importing (#104) + + +0.28.1 (2014-12-13) +^^^^^^^^^^^^^^^^^^^ + +- Improved EncryptedType to support more underlying_type's; now supports: Integer, Boolean, Date, Time, DateTime, ColorType, PhoneNumberType, Unicode(Text), String(Text), Enum +- Allow a callable to be used to lookup the key for EncryptedType + + +0.28.0 (2014-12-12) +^^^^^^^^^^^^^^^^^^^ + +- Fixed PhoneNumber string coercion (#93) +- Added observes decorator (generates decorator will be deprecated later) + + +0.27.11 (2014-12-06) +^^^^^^^^^^^^^^^^^^^^ + +- Added loose typed column checking support for get_column_key +- Made get_column_key throw UnmappedColumnError to be consistent with SQLAlchemy + + +0.27.10 (2014-12-03) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed column alias handling in dependent_objects + + +0.27.9 (2014-12-01) +^^^^^^^^^^^^^^^^^^^ + +- Fixed aggregated decorator many-to-many relationship handling +- Fixed aggregated column alias handling + + +0.27.8 (2014-11-13) +^^^^^^^^^^^^^^^^^^^ + +- Added is_loaded utility function +- Removed deprecated has_any_changes + + +0.27.7 (2014-11-03) +^^^^^^^^^^^^^^^^^^^ + +- Added support for Column and ColumnEntity objects in get_mapper +- Made make_order_by_deterministic add deterministic column more aggressively + + +0.27.6 (2014-10-29) +^^^^^^^^^^^^^^^^^^^ + +- Fixed assert_max_length not working with non nullable columns +- Add PostgreSQL < 9.2 support for drop_database + + +0.27.5 (2014-10-24) +^^^^^^^^^^^^^^^^^^^ + +- Made assert_* functions automatically rollback session +- Changed make_order_by_deterministic attach order by primary key for queries without order by +- Fixed alias handling in has_unique_index +- Fixed alias handling in has_index +- Fixed alias handling in make_order_by_deterministic + + +0.27.4 (2014-10-23) +^^^^^^^^^^^^^^^^^^^ + +- Added assert_non_nullable, assert_nullable and assert_max_length testing functions + + +0.27.3 (2014-10-22) +^^^^^^^^^^^^^^^^^^^ + +- Added supported for various SQLAlchemy objects in make_order_by_deterministic (previosly this function threw exceptions for other than Column objects) + + +0.27.2 (2014-10-21) +^^^^^^^^^^^^^^^^^^^ + +- Fixed MapperEntity handling in get_mapper and get_tables utility functions +- Fixed make_order_by_deterministic handling for queries without order by (no just silently ignores those rather than throws exception) +- Made make_order_by_deterministic if given query uses strings as order by args + + +0.27.1 (2014-10-20) +^^^^^^^^^^^^^^^^^^^ + +- Added support for more SQLAlchemy based objects and classes in get_tables function +- Added has_unique_index utility function +- Added make_order_by_deterministic utility function + + +0.27.0 (2014-10-14) +^^^^^^^^^^^^^^^^^^^ + +- Added EncryptedType + + +0.26.17 (2014-10-07) +^^^^^^^^^^^^^^^^^^^^ + +- Added explain and explain_analyze expressions +- Added analyze function + + +0.26.16 (2014-09-09) +^^^^^^^^^^^^^^^^^^^^ + +- Fix aggregate value handling for cascade deleted objects +- Fix ambiguous column sorting with join table inheritance in sort_query + + +0.26.15 (2014-08-28) +^^^^^^^^^^^^^^^^^^^^ + +- Fix sort_query support for queries using mappers (not declarative classes) with calculated column properties + + +0.26.14 (2014-08-26) +^^^^^^^^^^^^^^^^^^^^ + +- Added count method to QueryChain class + + +0.26.13 (2014-08-23) +^^^^^^^^^^^^^^^^^^^^ + +- Added template parameter to create_database function + + +0.26.12 (2014-08-22) +^^^^^^^^^^^^^^^^^^^^ + +- Added quote utility function + + +0.26.11 (2014-08-21) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed dependent_objects support for single table inheritance + + +0.26.10 (2014-08-13) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed dependent_objects support for multiple dependencies + + +0.26.9 (2014-08-06) +^^^^^^^^^^^^^^^^^^^ + +- Fixed PasswordType with Oracle dialect +- Added support for sort_query and attributes on mappers using with_polymorphic + + +0.26.8 (2014-07-30) +^^^^^^^^^^^^^^^^^^^ + +- Fixed order by column property handling in sort_query when using polymorphic inheritance +- Added support for synonym properties in sort_query + + +0.26.7 (2014-07-29) +^^^^^^^^^^^^^^^^^^^ + +- Made sort_query support hybrid properties where function name != property name +- Made get_hybrid_properties return a dictionary of property keys and hybrid properties +- Added documentation for get_hybrid_properties + + +0.26.6 (2014-07-22) +^^^^^^^^^^^^^^^^^^^ + +- Added exclude parameter to has_changes +- Made has_changes accept multiple attributes as second parameter + + +0.26.5 (2014-07-11) +^^^^^^^^^^^^^^^^^^^ + +- Added get_column_key +- Added Timestamp model mixin + + +0.26.4 (2014-06-25) +^^^^^^^^^^^^^^^^^^^ + +- Added auto_delete_orphans + + +0.26.3 (2014-06-25) +^^^^^^^^^^^^^^^^^^^ + +- Added has_any_changes + + +0.26.2 (2014-05-29) +^^^^^^^^^^^^^^^^^^^ + +- Added various fixes for bugs found in use of psycopg2 +- Added has_index + + +0.26.1 (2014-05-14) +^^^^^^^^^^^^^^^^^^^ + +- Added get_bind +- Added group_foreign_keys +- Added get_mapper +- Added merge_references + + +0.26.0 (2014-05-07) +^^^^^^^^^^^^^^^^^^^ + +- Added get_referencing_foreign_keys +- Added get_tables +- Added QueryChain +- Added dependent_objects + + +0.25.4 (2014-04-22) +^^^^^^^^^^^^^^^^^^^ + +- Added ExpressionParser + + +0.25.3 (2014-04-21) +^^^^^^^^^^^^^^^^^^^ + +- Added support for primary key aliases in get_primary_keys function +- Added get_columns utility function + + +0.25.2 (2014-03-25) +^^^^^^^^^^^^^^^^^^^ + +- Fixed sort_query handling of regular properties (no longer throws exceptions) + + +0.25.1 (2014-03-20) +^^^^^^^^^^^^^^^^^^^ + +- Added more import json as a fallback if anyjson package is not installed for JSONType +- Fixed query_entities labeled select handling + + +0.25.0 (2014-03-05) +^^^^^^^^^^^^^^^^^^^ + +- Added single table inheritance support for generic_relationship +- Added support for comparing class super types with generic relationships +- BC break: In order to support different inheritance strategies generic_relationship now uses class names as discriminators instead of table names. + + +0.24.4 (2014-03-05) +^^^^^^^^^^^^^^^^^^^ + +- Added hybrid_property support for generic_relationship + + +0.24.3 (2014-03-05) +^^^^^^^^^^^^^^^^^^^ + +- Added string argument support for generic_relationship +- Added composite primary key support for generic_relationship + + +0.24.2 (2014-03-04) +^^^^^^^^^^^^^^^^^^^ + +- Remove toolz from dependencies +- Add step argument support for all range types +- Optional intervals dependency updated to 0.2.4 + + +0.24.1 (2014-02-21) +^^^^^^^^^^^^^^^^^^^ + +- Made identity return a tuple in all cases +- Added support for declarative model classes as identity function's first argument + + +0.24.0 (2014-02-18) +^^^^^^^^^^^^^^^^^^^ + +- Added getdotattr +- Added Path and AttrPath classes +- SQLAlchemy dependency updated to 0.9.3 +- Optional intervals dependency updated to 0.2.2 + + +0.23.5 (2014-02-15) +^^^^^^^^^^^^^^^^^^^ + +- Fixed ArrowType timezone handling + + +0.23.4 (2014-01-30) +^^^^^^^^^^^^^^^^^^^ + +- Added force_instant_defaults function +- Added force_auto_coercion function +- Added source paramater for generates function + + +0.23.3 (2014-01-21) +^^^^^^^^^^^^^^^^^^^ + +- Fixed backref handling for aggregates +- Added support for many-to-many aggregates + + +0.23.2 (2014-01-21) +^^^^^^^^^^^^^^^^^^^ + +- Fixed issues with ColorType and ChoiceType string bound parameter processing +- Fixed inheritance handling with aggregates +- Fixed generic relationship nullifying + + +0.23.1 (2014-01-14) +^^^^^^^^^^^^^^^^^^^ + +- Added support for membership operators 'in' and 'not in' in range types +- Added support for contains and contained_by operators in range types +- Added range types to main module import + + +0.23.0 (2014-01-14) +^^^^^^^^^^^^^^^^^^^ + +- Deprecated NumberRangeType, NumberRange +- Added IntRangeType, NumericRangeType, DateRangeType, DateTimeRangeType +- Moved NumberRange functionality to intervals package + + +0.22.1 (2014-01-06) +^^^^^^^^^^^^^^^^^^^ + +- Fixed in issue where NumberRange would not always raise RangeBoundsException with object initialization + + +0.22.0 (2014-01-04) +^^^^^^^^^^^^^^^^^^^ + +- Added SQLAlchemy 0.9 support +- Made JSONType use sqlalchemy.dialects.postgresql.JSON if available +- Updated psycopg requirement to 2.5.1 +- Deprecated NumberRange classmethod constructors + + +0.21.0 (2013-11-11) +^^^^^^^^^^^^^^^^^^^ + +- Added support for cached aggregates + + +0.20.0 (2013-10-24) +^^^^^^^^^^^^^^^^^^^ + +- Added JSONType +- NumberRangeType now supports coercing of integer values + + +0.19.0 (2013-10-24) +^^^^^^^^^^^^^^^^^^^ + +- Added ChoiceType + + +0.18.0 (2013-10-24) +^^^^^^^^^^^^^^^^^^^ + +- Added LocaleType + + +0.17.1 (2013-10-23) +^^^^^^^^^^^^^^^^^^^ + +- Removed compat module, added total_ordering package to Python 2.6 requirements +- Enhanced render_statement function + + +0.17.0 (2013-10-23) +^^^^^^^^^^^^^^^^^^^ + +- Added URLType + + +0.16.25 (2013-10-18) +^^^^^^^^^^^^^^^^^^^^ + +- Added __ne__ operator implementation for Country object +- New utility function: naturally_equivalent + + +0.16.24 (2013-10-04) +^^^^^^^^^^^^^^^^^^^^ + +- Renamed match operator of TSVectorType to match_tsquery in order to avoid confusion with existing match operator +- Added catalog parameter support for match_tsquery operator + + +0.16.23 (2013-10-04) +^^^^^^^^^^^^^^^^^^^^ + +- Added match operator for TSVectorType + + +0.16.22 (2013-10-03) +^^^^^^^^^^^^^^^^^^^^ + +- Added optional columns and options parameter for TSVectorType + + +0.16.21 (2013-09-29) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed an issue with sort_query where sort by relationship property would cause an exception. + + +0.16.20 (2013-09-26) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed an issue with sort_query where sort by main entity's attribute would fail if joins where applied. + + +0.16.19 (2013-09-21) +^^^^^^^^^^^^^^^^^^^^ + +- Added configuration for silent mode in sort_query +- Added support for aliased entity hybrid properties in sort_query + + +0.16.18 (2013-09-19) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed sort_query hybrid property handling (again) + + +0.16.17 (2013-09-19) +^^^^^^^^^^^^^^^^^^^^ + +- Added support for relation hybrid property sorting in sort_query + + +0.16.16 (2013-09-18) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed fatal bug in batch fetch join table inheritance handling (not handling one-to-many relations properly) + + +0.16.15 (2013-09-17) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed sort_query hybrid property handling (now supports both ascending and descending sorting) + + +0.16.14 (2013-09-17) +^^^^^^^^^^^^^^^^^^^^ + +- More pythonic __init__ for Country allowing Country(Country('fi')) == Country('fi') +- Better equality operator for Country + + +0.16.13 (2013-09-17) +^^^^^^^^^^^^^^^^^^^^ + +- Added i18n module for configuration of locale dependant types + + +0.16.12 (2013-09-17) +^^^^^^^^^^^^^^^^^^^^ + +- Fixed remaining Python 3 issues with WeekDaysType +- Better bound method handling for WeekDay get_locale + + +0.16.11 (2013-09-17) +^^^^^^^^^^^^^^^^^^^^ + +- Python 3 support for WeekDaysType +- Fixed a bug in batch fetch for situations where joined paths contain zero entitites + + +0.16.10 (2013-09-16) +^^^^^^^^^^^^^^^^^^^^ + +- Added WeekDaysType + + +0.16.9 (2013-08-21) +^^^^^^^^^^^^^^^^^^^ + +- Support for many-to-one directed relationship properties batch fetching + + +0.16.8 (2013-08-21) +^^^^^^^^^^^^^^^^^^^ + +- PasswordType support for PostgreSQL +- Hybrid property for sort_query + + +0.16.7 (2013-08-18) +^^^^^^^^^^^^^^^^^^^ + +- Added better handling of local column names in batch_fetch +- PasswordType gets default length even if no crypt context schemes provided + + +0.16.6 (2013-08-16) +^^^^^^^^^^^^^^^^^^^ + +- Rewritten batch_fetch schematics, new syntax for backref population + + +0.16.5 (2013-08-08) +^^^^^^^^^^^^^^^^^^^ + +- Initial backref population forcing support for batch_fetch + + +0.16.4 (2013-08-08) +^^^^^^^^^^^^^^^^^^^ + +- Initial many-to-many relations support for batch_fetch + + +0.16.3 (2013-08-05) +^^^^^^^^^^^^^^^^^^^ + +- Added batch_fetch function + + +0.16.2 (2013-08-01) +^^^^^^^^^^^^^^^^^^^ + +- Added to_tsquery and plainto_tsquery sql function expressions + + +0.16.1 (2013-08-01) +^^^^^^^^^^^^^^^^^^^ + +- Added tsvector_concat and tsvector_match sql function expressions + + +0.16.0 (2013-07-25) +^^^^^^^^^^^^^^^^^^^ + +- Added ArrowType + + +0.15.1 (2013-07-22) +^^^^^^^^^^^^^^^^^^^ + +- Added utility functions declarative_base, identity and is_auto_assigned_date_column + + +0.15.0 (2013-07-22) +^^^^^^^^^^^^^^^^^^^ + +- Added PasswordType + + +0.14.7 (2013-07-22) +^^^^^^^^^^^^^^^^^^^ + +- Lazy import for ipaddress package + + +0.14.6 (2013-07-22) +^^^^^^^^^^^^^^^^^^^ + +- Fixed UUID import issues + + +0.14.5 (2013-07-22) +^^^^^^^^^^^^^^^^^^^ + +- Added UUID type + + +0.14.4 (2013-07-03) +^^^^^^^^^^^^^^^^^^^ + +- Added TSVector type + + +0.14.3 (2013-07-03) +^^^^^^^^^^^^^^^^^^^ + +- Added non_indexed_foreign_keys utility function + + +0.14.2 (2013-07-02) +^^^^^^^^^^^^^^^^^^^ + +- Fixed py3 bug introduced in 0.14.1 + + +0.14.1 (2013-07-02) +^^^^^^^^^^^^^^^^^^^ + +- Made sort_query support column_property selects with labels + + +0.14.0 (2013-07-02) +^^^^^^^^^^^^^^^^^^^ + +- Python 3 support, dropped python 2.5 support + + +0.13.3 (2013-06-11) +^^^^^^^^^^^^^^^^^^^ + +- Initial support for psycopg 2.5 NumericRange objects + + +0.13.2 (2013-06-11) +^^^^^^^^^^^^^^^^^^^ + +- QuerySorter now threadsafe. + + +0.13.1 (2013-06-11) +^^^^^^^^^^^^^^^^^^^ + +- Made sort_query function support multicolumn sorting. + + +0.13.0 (2013-06-05) +^^^^^^^^^^^^^^^^^^^ + +- Added table_name utility function. + + +0.12.5 (2013-06-05) +^^^^^^^^^^^^^^^^^^^ + +- ProxyDict now contains None values in cache - more efficient contains method. + + +0.12.4 (2013-06-01) +^^^^^^^^^^^^^^^^^^^ + +- Fixed ProxyDict contains method + + +0.12.3 (2013-05-30) +^^^^^^^^^^^^^^^^^^^ + +- Proxy dict expiration listener from function scope to global scope + + +0.12.2 (2013-05-29) +^^^^^^^^^^^^^^^^^^^ + +- Added automatic expiration of proxy dicts + + + +0.12.1 (2013-05-18) +^^^^^^^^^^^^^^^^^^^ + +- Added utility functions remove_property and primary_keys + + + +0.12.0 (2013-05-17) +^^^^^^^^^^^^^^^^^^^ + +- Added ProxyDict + + +0.11.0 (2013-05-08) +^^^^^^^^^^^^^^^^^^^ + +- Added coercion_listener + + +0.10.0 (2013-04-29) +^^^^^^^^^^^^^^^^^^^ + +- Added ColorType + + +0.9.1 (2013-04-15) +^^^^^^^^^^^^^^^^^^ + +- Renamed Email to EmailType and ScalarList to ScalarListType (unified type class naming convention) + + +0.9.0 (2013-04-11) +^^^^^^^^^^^^^^^^^^ + +- Added CaseInsensitiveComparator +- Added Email type + + +0.8.4 (2013-04-08) +^^^^^^^^^^^^^^^^^^ + +- Added sort by aliased and joined entity + + +0.8.3 (2013-04-03) +^^^^^^^^^^^^^^^^^^ + +- sort_query now supports labeled and subqueried scalars + + +0.8.2 (2013-04-03) +^^^^^^^^^^^^^^^^^^ + +- Fixed empty ScalarList handling + + +0.8.1 (2013-04-03) +^^^^^^^^^^^^^^^^^^ + +- Removed unnecessary print statement form ScalarList +- Documentation for ScalarList and NumberRange + + +0.8.0 (2013-04-02) +^^^^^^^^^^^^^^^^^^ + +- Added ScalarList type +- Fixed NumberRange bind param and result value processing + + +0.7.7 (2013-03-27) +^^^^^^^^^^^^^^^^^^ + +- Changed PhoneNumber string representation to the national phone number format + + +0.7.6 (2013-03-26) +^^^^^^^^^^^^^^^^^^ + +- NumberRange now wraps ValueErrors as NumberRangeExceptions + + +0.7.5 (2013-03-26) +^^^^^^^^^^^^^^^^^^ + +- Fixed defer_except +- Better string representations for NumberRange + + +0.7.4 (2013-03-26) +^^^^^^^^^^^^^^^^^^ + +- Fixed NumberRange upper bound parsing + + +0.7.3 (2013-03-26) +^^^^^^^^^^^^^^^^^^ + +- Enabled PhoneNumberType None value storing + + +0.7.2 (2013-03-26) +^^^^^^^^^^^^^^^^^^ + +- Enhanced string parsing for NumberRange + + +0.7.1 (2013-03-26) +^^^^^^^^^^^^^^^^^^ + +- Fixed requirements (now supports SQLAlchemy 0.8) + + +0.7.0 (2013-03-26) +^^^^^^^^^^^^^^^^^^ + +- Added NumberRange type + + + +0.6.0 (2013-03-26) +^^^^^^^^^^^^^^^^^^ + +- Extended PhoneNumber class from python-phonenumbers library + + +0.5.0 (2013-03-20) +^^^^^^^^^^^^^^^^^^ + +- Added PhoneNumberType type decorator + + +0.4.0 (2013-03-01) +^^^^^^^^^^^^^^^^^^ + +- Renamed SmartList to InstrumentedList +- Added instrumented_list decorator + + +0.3.0 (2013-03-01) +^^^^^^^^^^^^^^^^^^ + +- Added new collection class SmartList + + +0.2.0 (2013-03-01) +^^^^^^^^^^^^^^^^^^ + +- Added new function defer_except() + + +0.1.0 (2013-01-12) +^^^^^^^^^^^^^^^^^^ + +- Initial public release diff --git a/python-sqlalchemy-utils/LICENSE b/python-sqlalchemy-utils/LICENSE new file mode 100644 index 0000000..d604ce8 --- /dev/null +++ b/python-sqlalchemy-utils/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012, Konsta Vesterinen + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The names of the contributors may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python-sqlalchemy-utils/MANIFEST.in b/python-sqlalchemy-utils/MANIFEST.in new file mode 100644 index 0000000..cd07949 --- /dev/null +++ b/python-sqlalchemy-utils/MANIFEST.in @@ -0,0 +1,7 @@ +include CHANGES.rst LICENSE README.rst +recursive-include tests * +recursive-exclude tests *.pyc +recursive-include docs * +recursive-exclude docs *.pyc +prune docs/_build +exclude docs/_themes/.git diff --git a/python-sqlalchemy-utils/README.rst b/python-sqlalchemy-utils/README.rst new file mode 100644 index 0000000..6f3f57e --- /dev/null +++ b/python-sqlalchemy-utils/README.rst @@ -0,0 +1,22 @@ +SQLAlchemy-Utils +================ + +|Build Status| |Version Status| |Downloads| + + +Various utility functions, new data types and helpers for SQLAlchemy. + + +Resources +--------- + +- `Documentation `_ +- `Issue Tracker `_ +- `Code `_ + +.. |Build Status| image:: https://travis-ci.org/kvesteri/sqlalchemy-utils.svg?branch=master + :target: https://travis-ci.org/kvesteri/sqlalchemy-utils +.. |Version Status| image:: https://img.shields.io/pypi/v/SQLAlchemy-Utils.svg + :target: https://pypi.python.org/pypi/SQLAlchemy-Utils/ +.. |Downloads| image:: https://img.shields.io/pypi/dm/SQLAlchemy-Utils.svg + :target: https://pypi.python.org/pypi/SQLAlchemy-Utils/ diff --git a/python-sqlalchemy-utils/ROADMAP.rst b/python-sqlalchemy-utils/ROADMAP.rst new file mode 100644 index 0000000..3d634ea --- /dev/null +++ b/python-sqlalchemy-utils/ROADMAP.rst @@ -0,0 +1,9 @@ +* Add efficient pagination support + http://www.depesz.com/2011/05/20/pagination-with-fixed-order/ + http://stackoverflow.com/questions/6618366/improving-offset-performance-in-postgresql +* Generic file model + https://github.com/jpvanhal/silo +* Query to Postgres JSON converter + http://hashrocket.com/blog/posts/faster-json-generation-with-postgresql +* Postgres Cube datatype: + http://www.postgresql.org/docs/9.4/static/cube.html diff --git a/python-sqlalchemy-utils/docs/Makefile b/python-sqlalchemy-utils/docs/Makefile new file mode 100644 index 0000000..047bde5 --- /dev/null +++ b/python-sqlalchemy-utils/docs/Makefile @@ -0,0 +1,153 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = _build + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . + +.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext + +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + +clean: + -rm -rf $(BUILDDIR)/* + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/SQLAlchemy-Utils.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/SQLAlchemy-Utils.qhc" + +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/SQLAlchemy-Utils" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/SQLAlchemy-Utils" + @echo "# devhelp" + +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." diff --git a/python-sqlalchemy-utils/docs/aggregates.rst b/python-sqlalchemy-utils/docs/aggregates.rst new file mode 100644 index 0000000..5d262ff --- /dev/null +++ b/python-sqlalchemy-utils/docs/aggregates.rst @@ -0,0 +1,6 @@ +Aggregated attributes +===================== + +.. automodule:: sqlalchemy_utils.aggregates + +.. autofunction:: aggregated diff --git a/python-sqlalchemy-utils/docs/conf.py b/python-sqlalchemy-utils/docs/conf.py new file mode 100644 index 0000000..734ee21 --- /dev/null +++ b/python-sqlalchemy-utils/docs/conf.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +# +# SQLAlchemy-Utils documentation build configuration file, created by +# sphinx-quickstart on Tue Feb 19 11:16:09 2013. +# +# This file is execfile()d with the current directory set to its containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import sys, os + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.abspath('..')) +from sqlalchemy_utils import __version__ + +# -- General configuration ----------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +#needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', 'sphinx.ext.todo', 'sphinx.ext.coverage', 'sphinx.ext.ifconfig', 'sphinx.ext.viewcode'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix of source filenames. +source_suffix = '.rst' + +# The encoding of source files. +#source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'SQLAlchemy-Utils' +copyright = u'2013, Konsta Vesterinen' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = __version__ +# The full version, including alpha/beta/rc tags. +release = version + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +#language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +#today = '' +# Else, today_fmt is used as the format for a strftime call. +#today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ['_build'] + +# The reST default role (used for this markup: `text`) to use for all documents. +#default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +#add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +#add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +#show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting. +#modindex_common_prefix = [] + + +# -- Options for HTML output --------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = 'default' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +#html_theme_options = {} + +# Add any paths that contain custom themes here, relative to this directory. +#html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +#html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +#html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +#html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +#html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +#html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +#html_additional_pages = {} + +# If false, no module index is generated. +#html_domain_indices = True + +# If false, no index is generated. +#html_use_index = True + +# If true, the index is split into individual pages for each letter. +#html_split_index = False + +# If true, links to the reST sources are added to the pages. +#html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +#html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +#html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +#html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +#html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = 'SQLAlchemy-Utilsdoc' + + +# -- Options for LaTeX output -------------------------------------------------- + +latex_elements = { +# The paper size ('letterpaper' or 'a4paper'). +#'papersize': 'letterpaper', + +# The font size ('10pt', '11pt' or '12pt'). +#'pointsize': '10pt', + +# Additional stuff for the LaTeX preamble. +#'preamble': '', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [ + ('index', 'SQLAlchemy-Utils.tex', u'SQLAlchemy-Utils Documentation', + u'Konsta Vesterinen', 'manual'), +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +#latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +#latex_use_parts = False + +# If true, show page references after internal links. +#latex_show_pagerefs = False + +# If true, show URL addresses after external links. +#latex_show_urls = False + +# Documents to append as an appendix to all manuals. +#latex_appendices = [] + +# If false, no module index is generated. +#latex_domain_indices = True + + +# -- Options for manual page output -------------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + ('index', 'sqlalchemy-utils', u'SQLAlchemy-Utils Documentation', + [u'Konsta Vesterinen'], 1) +] + +# If true, show URL addresses after external links. +#man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------------ + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ('index', 'SQLAlchemy-Utils', u'SQLAlchemy-Utils Documentation', + u'Konsta Vesterinen', 'SQLAlchemy-Utils', 'One line description of project.', + 'Miscellaneous'), +] + +# Documents to append as an appendix to all manuals. +#texinfo_appendices = [] + +# If false, no module index is generated. +#texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +#texinfo_show_urls = 'footnote' + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = {'http://docs.python.org/': None} diff --git a/python-sqlalchemy-utils/docs/data_types.rst b/python-sqlalchemy-utils/docs/data_types.rst new file mode 100644 index 0000000..87fc895 --- /dev/null +++ b/python-sqlalchemy-utils/docs/data_types.rst @@ -0,0 +1,163 @@ +Data types +========== + +SQLAlchemy-Utils provides various new data types for SQLAlchemy. In order to gain full +advantage of these datatypes you should use automatic data coercion. See :func:`force_auto_coercion` for how to set up this feature. + +.. module:: sqlalchemy_utils.types + + +ArrowType +--------- + +.. module:: sqlalchemy_utils.types.arrow + +.. autoclass:: ArrowType + + +ChoiceType +---------- + +.. module:: sqlalchemy_utils.types.choice + +.. autoclass:: ChoiceType + + +ColorType +--------- + +.. module:: sqlalchemy_utils.types.color + +.. autoclass:: ColorType + + +CompositeType +------------- + +.. automodule:: sqlalchemy_utils.types.pg_composite + +.. autoclass:: CompositeType + + +CountryType +----------- + +.. module:: sqlalchemy_utils.types.country + +.. autoclass:: CountryType + +.. module:: sqlalchemy_utils.primitives.country + +.. autoclass:: Country + + +CurrencyType +------------ + +.. module:: sqlalchemy_utils.types.currency + +.. autoclass:: CurrencyType + +.. module:: sqlalchemy_utils.primitives.currency + +.. autoclass:: Currency + + +EncryptedType +------------- + +.. module:: sqlalchemy_utils.types.encrypted + +.. autoclass:: EncryptedType + +JSONType +-------- + +.. module:: sqlalchemy_utils.types.json + +.. autoclass:: JSONType + + +LocaleType +---------- + + +.. module:: sqlalchemy_utils.types.locale + +.. autoclass:: LocaleType + + +IPAddressType +------------- + +.. module:: sqlalchemy_utils.types.ip_address + +.. autoclass:: IPAddressType + + +PasswordType +------------ + +.. module:: sqlalchemy_utils.types.password + +.. autoclass:: PasswordType + + +PhoneNumberType +--------------- + +.. module:: sqlalchemy_utils.types.phone_number + +.. autoclass:: PhoneNumberType + + +ScalarListType +-------------- + +.. module:: sqlalchemy_utils.types.scalar_list + +.. autoclass:: ScalarListType + + +TimezoneType +------------ + + +.. module:: sqlalchemy_utils.types.timezone + +.. autoclass:: TimezoneType + + +TSVectorType +------------ + +.. module:: sqlalchemy_utils.types.ts_vector + +.. autoclass:: TSVectorType + + +URLType +------- + +.. module:: sqlalchemy_utils.types.url + +.. autoclass:: URLType + + +UUIDType +-------- + + +.. module:: sqlalchemy_utils.types.uuid + +.. autoclass:: UUIDType + + + +WeekDaysType +------------ + +.. module:: sqlalchemy_utils.types.weekdays + +.. autoclass:: WeekDaysType + diff --git a/python-sqlalchemy-utils/docs/database_helpers.rst b/python-sqlalchemy-utils/docs/database_helpers.rst new file mode 100644 index 0000000..df349ba --- /dev/null +++ b/python-sqlalchemy-utils/docs/database_helpers.rst @@ -0,0 +1,59 @@ +Database helpers +================ + + +.. module:: sqlalchemy_utils.functions + + +analyze +------- + +.. autofunction:: analyze + + +database_exists +--------------- + +.. autofunction:: database_exists + + +create_database +--------------- + +.. autofunction:: create_database + + +drop_database +------------- + +.. autofunction:: drop_database + + +has_index +--------- + +.. autofunction:: has_index + + +has_unique_index +---------------- + +.. autofunction:: has_unique_index + + +json_sql +-------- + +.. autofunction:: json_sql + + +render_expression +----------------- + +.. autofunction:: render_expression + + +render_statement +---------------- + +.. autofunction:: render_statement diff --git a/python-sqlalchemy-utils/docs/foreign_key_helpers.rst b/python-sqlalchemy-utils/docs/foreign_key_helpers.rst new file mode 100644 index 0000000..0d309a4 --- /dev/null +++ b/python-sqlalchemy-utils/docs/foreign_key_helpers.rst @@ -0,0 +1,40 @@ +Foreign key helpers +=================== + +.. module:: sqlalchemy_utils.functions + + +dependent_objects +----------------- + +.. autofunction:: dependent_objects + + +get_referencing_foreign_keys +---------------------------- + +.. autofunction:: get_referencing_foreign_keys + + +group_foreign_keys +------------------ + +.. autofunction:: group_foreign_keys + + +is_indexed_foreign_key +---------------------- + +.. autofunction:: is_indexed_foreign_key + + +merge_references +---------------- + +.. autofunction:: merge_references + + +non_indexed_foreign_keys +------------------------ + +.. autofunction:: non_indexed_foreign_keys diff --git a/python-sqlalchemy-utils/docs/generic_relationship.rst b/python-sqlalchemy-utils/docs/generic_relationship.rst new file mode 100644 index 0000000..bb61ecb --- /dev/null +++ b/python-sqlalchemy-utils/docs/generic_relationship.rst @@ -0,0 +1,171 @@ +Generic relationships +===================== + +Generic relationship is a form of relationship that supports creating a 1 to many relationship to any target model. + +:: + + from sqlalchemy_utils import generic_relationship + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Customer(Base): + __tablename__ = 'customer' + id = sa.Column(sa.Integer, primary_key=True) + + class Event(Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + # This is used to discriminate between the linked tables. + object_type = sa.Column(sa.Unicode(255)) + + # This is used to point to the primary key of the linked row. + object_id = sa.Column(sa.Integer) + + object = generic_relationship(object_type, object_id) + + + # Some general usage to attach an event to a user. + user = User() + customer = Customer() + + session.add_all([user, customer]) + session.commit() + + ev = Event() + ev.object = user + + session.add(ev) + session.commit() + + # Find the event we just made. + session.query(Event).filter_by(object=user).first() + + # Find any events that are bound to users. + session.query(Event).filter(Event.object.is_type(User)).all() + + +Inheritance +----------- + +:: + + class Employee(self.Base): + __tablename__ = 'employee' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(50)) + type = sa.Column(sa.String(20)) + + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'employee' + } + + class Manager(Employee): + __mapper_args__ = { + 'polymorphic_identity': 'manager' + } + + class Engineer(Employee): + __mapper_args__ = { + 'polymorphic_identity': 'engineer' + } + + class Activity(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + + +Now same as before we can add some objects:: + + manager = Manager() + + session.add(manager) + session.commit() + + activity = Activity() + activity.object = manager + + session.add(activity) + session.commit() + + # Find the activity we just made. + session.query(Event).filter_by(object=manager).first() + + +We can even test super types:: + + + session.query(Activity).filter(Event.object.is_type(Employee)).all() + + +Abstract base classes +--------------------- + +Generic relationships also allows using string arguments. When using generic_relationship with abstract base classes you need to set up the relationship using declared_attr decorator and string arguments. + + +:: + + + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class EventBase(self.Base): + __abstract__ = True + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + @declared_attr + def object(cls): + return generic_relationship('object_type', 'object_id') + + class Event(EventBase): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + +Composite keys +-------------- + +For some very rare cases you may need to use generic_relationships with composite primary keys. There is a limitation here though: you can only set up generic_relationship for similar composite primary key types. In other words you can't mix generic relationship to both composite keyed objects and single keyed objects. + +:: + + from sqlalchemy_utils import generic_relationship + + + class Customer(Base): + __tablename__ = 'customer' + code1 = sa.Column(sa.Integer, primary_key=True) + code2 = sa.Column(sa.Integer, primary_key=True) + + + class Event(Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + # This is used to discriminate between the linked tables. + object_type = sa.Column(sa.Unicode(255)) + + object_code1 = sa.Column(sa.Integer) + + object_code2 = sa.Column(sa.Integer) + + object = generic_relationship( + object_type, (object_code1, object_code2) + ) diff --git a/python-sqlalchemy-utils/docs/index.rst b/python-sqlalchemy-utils/docs/index.rst new file mode 100644 index 0000000..455e58e --- /dev/null +++ b/python-sqlalchemy-utils/docs/index.rst @@ -0,0 +1,25 @@ +SQLAlchemy-Utils +================ + + +SQLAlchemy-Utils provides custom data types and various utility functions for SQLAlchemy. + + +.. toctree:: + :maxdepth: 2 + + installation + listeners + data_types + range_data_types + aggregates + observers + internationalization + generic_relationship + database_helpers + foreign_key_helpers + orm_helpers + utility_classes + models + testing + license diff --git a/python-sqlalchemy-utils/docs/installation.rst b/python-sqlalchemy-utils/docs/installation.rst new file mode 100644 index 0000000..27fb3bf --- /dev/null +++ b/python-sqlalchemy-utils/docs/installation.rst @@ -0,0 +1,53 @@ +Installation +============ + +This part of the documentation covers the installation of SQLAlchemy-Utils. + +Supported platforms +------------------- + +SQLAlchemy-Utils has been tested against the following Python platforms. + +- cPython 2.6 +- cPython 2.7 +- cPython 3.3 + + +Installing an official release +------------------------------ + +You can install the most recent official SQLAlchemy-Utils version using +pip_:: + + pip install sqlalchemy-utils + +.. _pip: http://www.pip-installer.org/ + +Installing the development version +---------------------------------- + +To install the latest version of SQLAlchemy-Utils, you need first obtain a +copy of the source. You can do that by cloning the git_ repository:: + + git clone git://github.com/kvesteri/sqlalchemy-utils.git + +Then you can install the source distribution using the ``setup.py`` +script:: + + cd sqlalchemy-utils + python setup.py install + +.. _git: http://git-scm.org/ + +Checking the installation +------------------------- + +To check that SQLAlchemy-Utils has been properly installed, type ``python`` +from your shell. Then at the Python prompt, try to import SQLAlchemy-Utils, +and check the installed version: + +.. parsed-literal:: + + >>> import sqlalchemy_utils + >>> sqlalchemy_utils.__version__ + |release| diff --git a/python-sqlalchemy-utils/docs/internationalization.rst b/python-sqlalchemy-utils/docs/internationalization.rst new file mode 100644 index 0000000..6c9d400 --- /dev/null +++ b/python-sqlalchemy-utils/docs/internationalization.rst @@ -0,0 +1,153 @@ +Internationalization +==================== + +SQLAlchemy-Utils provides a way for modeling translatable models. Model is +translatable if one or more of its columns can be displayed in various languages. + +.. note:: + + The implementation is currently highly PostgreSQL specific since it needs + a dict-compatible column type (PostgreSQL HSTORE and JSON are such types). + If you want database-agnostic way of modeling i18n see `SQLAlchemy-i18n`_. + + +TranslationHybrid vs SQLAlchemy-i18n +------------------------------------ + +Compared to SQLAlchemy-i18n the TranslationHybrid has the following pros and cons: + +* Usually faster since no joins are needed for fetching the data +* Less magic +* Easier to understand data model +* Only PostgreSQL supported for now + + +Quickstart +---------- + +Let's say we have an Article model with translatable name and content. First we +need to define the TranslationHybrid. + +:: + + from sqlalchemy_utils import TranslationHybrid + + + # For testing purposes we define this as simple function which returns + # locale 'fi'. Usually you would define this function as something that + # returns the user's current locale. + def get_locale(): + return 'fi' + + + translation_hybrid = TranslationHybrid( + current_locale=get_locale, + default_locale='en' + ) + + +Then we can define the model.:: + + + from sqlalchemy import * + from sqlalchemy.dialects.postgresql import HSTORE + + + class Article(Base): + __tablename__ = 'article' + + id = Column(Integer, primary_key=True) + name_translations = Column(HSTORE) + content_translations = Column(HSTORE) + + name = translation_hybrid(name_translations) + content = translation_hybrid(content_translations) + + +Now we can start using our translatable model. By assigning things to +translatable hybrids you are assigning them to the locale returned by the +`current_locale`. +:: + + + article = Article(name='Joku artikkeli') + article.name_translations['fi'] # Joku artikkeli + article.name # Joku artikkeli + + +If you access the hybrid with a locale that doesn't exist the hybrid tries to +fetch a the locale returned by `default_locale`. +:: + + article = Article(name_translations={'en': 'Some article'}) + article.name # Some article + article.name_translations['fi'] = 'Joku artikkeli' + article.name # Joku artikkeli + + +Translation hybrids can also be used as expressions. +:: + + session.query(Article).filter(Article.name['en'] == 'Some article') + + +By default if no value is found for either current or default locale the +translation hybrid returns `None`. You can customize this value with `default_value` parameter +of translation_hybrid. In the following example we make translation hybrid fallback to empty string instead of `None`. + +:: + + translation_hybrid = TranslationHybrid( + current_locale=get_locale, + default_locale='en', + default_value='' + ) + + + class Article(Base): + __tablename__ = 'article' + + id = Column(Integer, primary_key=True) + name_translations = Column(HSTORE) + + name = translation_hybrid(name_translations, default) + + + Article().name # '' + + +Dynamic locales +--------------- + +Sometimes locales need to be dynamic. The following example illustrates how to setup +dynamic locales. + + +:: + + translation_hybrid = TranslationHybrid( + current_locale=get_locale, + default_locale=lambda obj: obj.locale, + ) + + + class Article(Base): + __tablename__ = 'article' + + id = Column(Integer, primary_key=True) + name_translations = Column(HSTORE) + + name = translation_hybrid(name_translations, default) + locale = Column(String) + + + article = Article(name_translations={'en': 'Some article'}) + session.add(article) + session.commit() + + article.name # Some article (even if current locale is other than 'en') + + + + +.. _SQLAlchemy-i18n: https://github.com/kvesteri/sqlalchemy-i18n diff --git a/python-sqlalchemy-utils/docs/license.rst b/python-sqlalchemy-utils/docs/license.rst new file mode 100644 index 0000000..7e6291f --- /dev/null +++ b/python-sqlalchemy-utils/docs/license.rst @@ -0,0 +1,4 @@ +License +======= + +.. include:: ../LICENSE diff --git a/python-sqlalchemy-utils/docs/listeners.rst b/python-sqlalchemy-utils/docs/listeners.rst new file mode 100644 index 0000000..275c138 --- /dev/null +++ b/python-sqlalchemy-utils/docs/listeners.rst @@ -0,0 +1,22 @@ +Listeners +========= + + +.. module:: sqlalchemy_utils.listeners + +Automatic data coercion +----------------------- + +.. autofunction:: force_auto_coercion + + +Instant defaults +---------------- + +.. autofunction:: force_instant_defaults + + +Many-to-many orphan deletion +---------------------------- + +.. autofunction:: auto_delete_orphans diff --git a/python-sqlalchemy-utils/docs/make.bat b/python-sqlalchemy-utils/docs/make.bat new file mode 100644 index 0000000..0298f86 --- /dev/null +++ b/python-sqlalchemy-utils/docs/make.bat @@ -0,0 +1,190 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\SQLAlchemy-Utils.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\SQLAlchemy-Utils.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +:end diff --git a/python-sqlalchemy-utils/docs/models.rst b/python-sqlalchemy-utils/docs/models.rst new file mode 100644 index 0000000..ae7e0ed --- /dev/null +++ b/python-sqlalchemy-utils/docs/models.rst @@ -0,0 +1,10 @@ +Model mixins +============ + + +Timestamp +--------- + +.. module:: sqlalchemy_utils.models + +.. autoclass:: Timestamp diff --git a/python-sqlalchemy-utils/docs/observers.rst b/python-sqlalchemy-utils/docs/observers.rst new file mode 100644 index 0000000..33a12a9 --- /dev/null +++ b/python-sqlalchemy-utils/docs/observers.rst @@ -0,0 +1,6 @@ +Observers +========= + +.. automodule:: sqlalchemy_utils.observer + +.. autofunction:: observes diff --git a/python-sqlalchemy-utils/docs/orm_helpers.rst b/python-sqlalchemy-utils/docs/orm_helpers.rst new file mode 100644 index 0000000..eac570e --- /dev/null +++ b/python-sqlalchemy-utils/docs/orm_helpers.rst @@ -0,0 +1,118 @@ +ORM helpers +=========== + +.. module:: sqlalchemy_utils.functions + + +escape_like +----------- + +.. autofunction:: escape_like + + +get_bind +-------- + +.. autofunction:: get_bind + + +get_class_by_table +------------------ + +.. autofunction:: get_class_by_table + + +get_column_key +-------------- + +.. autofunction:: get_column_key + + +get_columns +----------- + +.. autofunction:: get_columns + + +get_declarative_base +-------------------- + +.. autofunction:: get_declarative_base + + +get_hybrid_properties +--------------------- + +.. autofunction:: get_hybrid_properties + + +get_mapper +---------- + +.. autofunction:: get_mapper + + +get_query_entities +------------------ + +.. autofunction:: get_query_entities + + +get_primary_keys +---------------- + +.. autofunction:: get_primary_keys + + +get_tables +---------- + +.. autofunction:: get_tables + + +get_type +-------- + +.. autofunction:: get_type + + +has_changes +----------- + +.. autofunction:: has_changes + + +identity +-------- + +.. autofunction:: identity + + +is_loaded +--------- + +.. autofunction:: is_loaded + + +make_order_by_deterministic +--------------------------- + +.. autofunction:: make_order_by_deterministic + + +naturally_equivalent +-------------------- + +.. autofunction:: naturally_equivalent + + +quote +----- + +.. autofunction:: quote + + +sort_query +---------- + +.. autofunction:: sort_query diff --git a/python-sqlalchemy-utils/docs/range_data_types.rst b/python-sqlalchemy-utils/docs/range_data_types.rst new file mode 100644 index 0000000..e57abe1 --- /dev/null +++ b/python-sqlalchemy-utils/docs/range_data_types.rst @@ -0,0 +1,37 @@ +Range data types +================ + +.. automodule:: sqlalchemy_utils.types.range + + + + +DateRangeType +------------- + +.. autoclass:: DateRangeType + + +DateTimeRangeType +----------------- + +.. autoclass:: DateTimeRangeType + + +IntRangeType +------------ + +.. autoclass:: IntRangeType + + +NumericRangeType +---------------- + +.. autoclass:: NumericRangeType + + +RangeComparator +--------------- + +.. autoclass:: RangeComparator + :members: diff --git a/python-sqlalchemy-utils/docs/testing.rst b/python-sqlalchemy-utils/docs/testing.rst new file mode 100644 index 0000000..024a297 --- /dev/null +++ b/python-sqlalchemy-utils/docs/testing.rst @@ -0,0 +1,30 @@ +Testing +======= + +.. automodule:: sqlalchemy_utils.asserts + + +assert_min_value +---------------- + +.. autofunction:: assert_min_value + +assert_max_length +----------------- + +.. autofunction:: assert_max_length + +assert_max_value +---------------- + +.. autofunction:: assert_max_value + +assert_nullable +--------------- + +.. autofunction:: assert_nullable + +assert_non_nullable +------------------- + +.. autofunction:: assert_non_nullable diff --git a/python-sqlalchemy-utils/docs/utility_classes.rst b/python-sqlalchemy-utils/docs/utility_classes.rst new file mode 100644 index 0000000..c40f008 --- /dev/null +++ b/python-sqlalchemy-utils/docs/utility_classes.rst @@ -0,0 +1,13 @@ +Utility classes +=============== + +QueryChain +---------- + +.. automodule:: sqlalchemy_utils.query_chain + +API +--- + +.. autoclass:: QueryChain + :members: diff --git a/python-sqlalchemy-utils/setup.py b/python-sqlalchemy-utils/setup.py new file mode 100644 index 0000000..5727bd0 --- /dev/null +++ b/python-sqlalchemy-utils/setup.py @@ -0,0 +1,109 @@ +""" +SQLAlchemy-Utils +---------------- + +Various utility functions and custom data types for SQLAlchemy. +""" +from setuptools import setup, find_packages +import os +import re +import sys + + +HERE = os.path.dirname(os.path.abspath(__file__)) +PY3 = sys.version_info[0] == 3 + + +def get_version(): + filename = os.path.join(HERE, 'sqlalchemy_utils', '__init__.py') + with open(filename) as f: + contents = f.read() + pattern = r"^__version__ = '(.*?)'$" + return re.search(pattern, contents, re.MULTILINE).group(1) + + +PY3 = sys.version_info[0] == 3 + + +extras_require = { + 'test': [ + 'pytest==2.3.5', + 'Pygments>=1.2', + 'Jinja2>=2.3', + 'docutils>=0.10', + 'flexmock>=0.9.7', + 'psycopg2>=2.5.1', + 'pytz>=2014.2', + 'python-dateutil>=2.2', + 'pymysql', + 'flake8>=2.4.0', + 'isort==3.9.6', + 'natsort==3.5.6', + ], + 'anyjson': ['anyjson>=0.3.3'], + 'babel': ['Babel>=1.3'], + 'arrow': ['arrow>=0.3.4'], + 'intervals': ['intervals>=0.2.4'], + 'phone': ['phonenumbers>=5.9.2'], + 'password': ['passlib >= 1.6, < 2.0'], + 'color': ['colour>=0.0.4'], + 'ipaddress': ['ipaddr'] if not PY3 else [], + 'enum': ['enum34'] if sys.version_info < (3, 4) else [], + 'timezone': ['python-dateutil'], + 'url': ['furl >= 0.4.1'], + 'encrypted': ['cryptography>=0.6'] +} + + +# Add all optional dependencies to testing requirements. +test_all = [] +for name, requirements in extras_require.items(): + test_all += requirements +extras_require['test_all'] = test_all + + +setup( + name='SQLAlchemy-Utils', + version=get_version(), + url='https://github.com/kvesteri/sqlalchemy-utils', + license='BSD', + author='Konsta Vesterinen, Ryan Leckey, Janne Vanhala, Vesa Uimonen', + author_email='konsta@fastmonkeys.com', + description=( + 'Various utility functions for SQLAlchemy.' + ), + long_description=__doc__, + packages=find_packages('.', exclude=['tests', 'tests.*']), + zip_safe=False, + include_package_data=True, + platforms='any', + dependency_links=[ + # 5.6b1 only supports python 3.x / pending release + 'git+git://github.com/daviddrysdale/python-phonenumbers.git@python3' + '#egg=phonenumbers3k-5.6b1', + ], + install_requires=[ + 'six', + 'SQLAlchemy>=1.0', + 'total_ordering>=0.1' + if sys.version_info[0] == 2 and sys.version_info[1] < 7 else '', + 'ordereddict>=1.1' + if sys.version_info[0] == 2 and sys.version_info[1] < 7 else '', + ], + extras_require=extras_require, + classifiers=[ + 'Environment :: Web Environment', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: BSD License', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.6', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', + 'Topic :: Software Development :: Libraries :: Python Modules' + ] +) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/__init__.py b/python-sqlalchemy-utils/sqlalchemy_utils/__init__.py new file mode 100644 index 0000000..a2cc96c --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/__init__.py @@ -0,0 +1,95 @@ +from .aggregates import aggregated # noqa +from .asserts import ( # noqa + assert_max_length, + assert_max_value, + assert_min_value, + assert_non_nullable, + assert_nullable +) +from .exceptions import ImproperlyConfigured # noqa +from .expression_parser import ExpressionParser # noqa +from .expressions import Asterisk, row_to_json # noqa +from .functions import ( # noqa + analyze, + create_database, + create_mock_engine, + database_exists, + dependent_objects, + drop_database, + escape_like, + get_bind, + get_class_by_table, + get_column_key, + get_columns, + get_declarative_base, + get_hybrid_properties, + get_mapper, + get_primary_keys, + get_query_entities, + get_referencing_foreign_keys, + get_tables, + get_type, + group_foreign_keys, + has_changes, + has_index, + has_unique_index, + identity, + is_loaded, + json_sql, + merge_references, + mock_engine, + naturally_equivalent, + render_expression, + render_statement, + sort_query, + table_name +) +from .generic import generic_relationship # noqa +from .i18n import TranslationHybrid # noqa +from .listeners import ( # noqa + auto_delete_orphans, + coercion_listener, + force_auto_coercion, + force_instant_defaults +) +from .models import Timestamp # noqa +from .observer import observes # noqa +from .primitives import Country, Currency, WeekDay, WeekDays # noqa +from .proxy_dict import proxy_dict, ProxyDict # noqa +from .query_chain import QueryChain # noqa +from .types import ( # noqa + ArrowType, + Choice, + ChoiceType, + ColorType, + CompositeArray, + CompositeType, + CountryType, + CurrencyType, + DateRangeType, + DateTimeRangeType, + EmailType, + EncryptedType, + instrumented_list, + InstrumentedList, + IntRangeType, + IPAddressType, + JSONType, + LocaleType, + NumericRangeType, + Password, + PasswordType, + PhoneNumber, + PhoneNumberType, + register_composites, + remove_composite_listeners, + ScalarListException, + ScalarListType, + TimezoneType, + TSVectorType, + URLType, + UUIDType, + WeekDaysType +) + +__version__ = '0.30.12' diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/aggregates.py b/python-sqlalchemy-utils/sqlalchemy_utils/aggregates.py new file mode 100644 index 0000000..c02b1f2 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/aggregates.py @@ -0,0 +1,576 @@ +""" +SQLAlchemy-Utils provides way of automatically calculating aggregate values of +related models and saving them to parent model. + +This solution is inspired by RoR counter cache, +`counter_culture`_ and `stackoverflow reply by Michael Bayer`_. + +Why? +---- + +Many times you may have situations where you need to calculate dynamically some +aggregate value for given model. Some simple examples include: + +- Number of products in a catalog +- Average rating for movie +- Latest forum post +- Total price of orders for given customer + +Now all these aggregates can be elegantly implemented with SQLAlchemy +column_property_ function. However when your data grows calculating these +values on the fly might start to hurt the performance of your application. The +more aggregates you are using the more performance penalty you get. + +This module provides way of calculating these values automatically and +efficiently at the time of modification rather than on the fly. + + +Features +-------- + +* Automatically updates aggregate columns when aggregated values change +* Supports aggregate values through arbitrary number levels of relations +* Highly optimized: uses single query per transaction per aggregate column +* Aggregated columns can be of any data type and use any selectable scalar + expression + + +.. _column_property: + http://docs.sqlalchemy.org/en/latest/orm/mapper_config.html#using-column-property +.. _counter_culture: https://github.com/magnusvk/counter_culture +.. _stackoverflow reply by Michael Bayer: + http://stackoverflow.com/questions/13693872/ + + +Simple aggregates +----------------- + +:: + + from sqlalchemy_utils import aggregated + + + class Thread(Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('comments', sa.Column(sa.Integer)) + def comment_count(self): + return sa.func.count('1') + + comments = sa.orm.relationship( + 'Comment', + backref='thread' + ) + + + class Comment(Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.UnicodeText) + thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id)) + + + thread = Thread(name=u'SQLAlchemy development') + thread.comments.append(Comment(u'Going good!')) + thread.comments.append(Comment(u'Great new features!')) + + session.add(thread) + session.commit() + + thread.comment_count # 2 + + + +Custom aggregate expressions +---------------------------- + +Aggregate expression can be virtually any SQL expression not just a simple +function taking one parameter. You can try things such as subqueries and +different kinds of functions. + +In the following example we have a Catalog of products where each catalog +knows the net worth of its products. + +:: + + + from sqlalchemy_utils import aggregated + + + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('products', sa.Column(sa.Integer)) + def net_worth(self): + return sa.func.sum(Product.price) + + products = sa.orm.relationship('Product') + + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id)) + + +Now the net_worth column of Catalog model will be automatically whenever: + +* A new product is added to the catalog +* A product is deleted from the catalog +* The price of catalog product is changed + + +:: + + + from decimal import Decimal + + + product1 = Product(name='Some product', price=Decimal(1000)) + product2 = Product(name='Some other product', price=Decimal(500)) + + + catalog = Catalog( + name=u'My first catalog', + products=[ + product1, + product2 + ] + ) + session.add(catalog) + session.commit() + + session.refresh(catalog) + catalog.net_worth # 1500 + + session.delete(product2) + session.commit() + session.refresh(catalog) + + catalog.net_worth # 1000 + + product1.price = 2000 + session.commit() + session.refresh(catalog) + + catalog.net_worth # 2000 + + + + +Multiple aggregates per class +----------------------------- + +Sometimes you may need to define multiple aggregate values for same class. If +you need to define lots of relationships pointing to same class, remember to +define the relationships as viewonly when possible. + + +:: + + + from sqlalchemy_utils import aggregated + + + class Customer(Base): + __tablename__ = 'customer' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('orders', sa.Column(sa.Integer)) + def orders_sum(self): + return sa.func.sum(Order.price) + + @aggregated('invoiced_orders', sa.Column(sa.Integer)) + def invoiced_orders_sum(self): + return sa.func.sum(Order.price) + + orders = sa.orm.relationship('Order') + + invoiced_orders = sa.orm.relationship( + 'Order', + primaryjoin= + 'sa.and_(Order.customer_id == Customer.id, Order.invoiced)', + viewonly=True + ) + + + class Order(Base): + __tablename__ = 'order' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + invoiced = sa.Column(sa.Boolean, default=False) + customer_id = sa.Column(sa.Integer, sa.ForeignKey(Customer.id)) + + +Many-to-Many aggregates +----------------------- + +Aggregate expressions also support many-to-many relationships. The usual use +scenarios includes things such as: + +1. Friend count of a user +2. Group count where given user belongs to + +:: + + + user_group = sa.Table('user_group', Base.metadata, + sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')), + sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) + ) + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('groups', sa.Column(sa.Integer, default=0)) + def group_count(self): + return sa.func.count('1') + + groups = sa.orm.relationship( + 'Group', + backref='users', + secondary=user_group + ) + + + class Group(Base): + __tablename__ = 'group' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + + + user = User(name=u'John Matrix') + user.groups = [Group(name=u'Group A'), Group(name=u'Group B')] + + session.add(user) + session.commit() + + session.refresh(user) + user.group_count # 2 + + +Multi-level aggregates +---------------------- + +Aggregates can span accross multiple relationships. In the following example +each Catalog has a net_worth which is the sum of all products in all +categories. + + +:: + + + from sqlalchemy_utils import aggregated + + + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('categories.products', sa.Column(sa.Integer)) + def net_worth(self): + return sa.func.sum(Product.price) + + categories = sa.orm.relationship('Product') + + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id)) + + products = sa.orm.relationship('Product') + + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + +Examples +-------- + +Average movie rating +^^^^^^^^^^^^^^^^^^^^ + +:: + + + from sqlalchemy_utils import aggregated + + + class Movie(Base): + __tablename__ = 'movie' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('ratings', sa.Column(sa.Numeric)) + def avg_rating(self): + return sa.func.avg(Rating.stars) + + ratings = sa.orm.relationship('Rating') + + + class Rating(Base): + __tablename__ = 'rating' + id = sa.Column(sa.Integer, primary_key=True) + stars = sa.Column(sa.Integer) + + movie_id = sa.Column(sa.Integer, sa.ForeignKey(Movie.id)) + + + movie = Movie('Terminator 2') + movie.ratings.append(Rating(stars=5)) + movie.ratings.append(Rating(stars=4)) + movie.ratings.append(Rating(stars=3)) + session.add(movie) + session.commit() + + movie.avg_rating # 4 + + + +TODO +---- + +* Special consideration should be given to `deadlocks`_. + + +.. _deadlocks: + http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html + +""" + + +from collections import defaultdict +from weakref import WeakKeyDictionary + +import six +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declared_attr + +from .functions.orm import get_column_key +from .relationships import chained_join, select_aggregate + +try: + # SQLAlchemy 0.9 + from sqlalchemy.sql.functions import _FunctionGenerator +except ImportError: + # SQLAlchemy 0.8 + from sqlalchemy.sql.expression import _FunctionGenerator + + +aggregated_attrs = WeakKeyDictionary(defaultdict(list)) + + +class AggregatedAttribute(declared_attr): + def __init__( + self, + fget, + relationship, + column, + *args, + **kwargs + ): + super(AggregatedAttribute, self).__init__(fget, *args, **kwargs) + self.__doc__ = fget.__doc__ + self.column = column + self.relationship = relationship + + def __get__(desc, self, cls): + value = (desc.fget, desc.relationship, desc.column) + if cls not in aggregated_attrs: + aggregated_attrs[cls] = [value] + else: + aggregated_attrs[cls].append(value) + return desc.column + + +def local_condition(prop, objects): + pairs = prop.local_remote_pairs + if prop.secondary is not None: + parent_column = pairs[1][0] + fetched_column = pairs[1][0] + else: + parent_column = pairs[0][0] + fetched_column = pairs[0][1] + + key = get_column_key(prop.mapper, fetched_column) + + values = [] + for obj in objects: + try: + values.append(getattr(obj, key)) + except sa.orm.exc.ObjectDeletedError: + pass + + if values: + return parent_column.in_(values) + + +def aggregate_expression(expr, class_): + if isinstance(expr, sa.sql.visitors.Visitable): + return expr + elif isinstance(expr, _FunctionGenerator): + return expr(sa.sql.text('1')) + else: + return expr(class_) + + +class AggregatedValue(object): + def __init__(self, class_, attr, relationships, expr): + self.class_ = class_ + self.attr = attr + self.relationships = relationships + self.expr = aggregate_expression(expr, class_) + + @property + def aggregate_query(self): + query = select_aggregate(self.expr, self.relationships) + + return query.correlate(self.class_).as_scalar() + + def update_query(self, objects): + table = self.class_.__table__ + query = table.update().values( + {self.attr: self.aggregate_query} + ) + if len(self.relationships) == 1: + prop = self.relationships[-1].property + condition = local_condition(prop, objects) + if condition is not None: + return query.where(condition) + else: + # Builds query such as: + # + # UPDATE catalog SET product_count = (aggregate_query) + # WHERE id IN ( + # SELECT catalog_id + # FROM category + # INNER JOIN sub_category + # ON category.id = sub_category.category_id + # WHERE sub_category.id IN (product_sub_category_ids) + # ) + property_ = self.relationships[-1].property + remote_pairs = property_.local_remote_pairs + local = remote_pairs[0][0] + remote = remote_pairs[0][1] + condition = local_condition( + self.relationships[0].property, + objects + ) + if condition is not None: + return query.where( + local.in_( + sa.select( + [remote], + from_obj=[ + chained_join(*reversed(self.relationships)) + ] + ).where( + condition + ) + ) + ) + + +class AggregationManager(object): + def __init__(self): + self.reset() + + def reset(self): + self.generator_registry = defaultdict(list) + + def register_listeners(self): + sa.event.listen( + sa.orm.mapper, + 'after_configured', + self.update_generator_registry + ) + sa.event.listen( + sa.orm.session.Session, + 'after_flush', + self.construct_aggregate_queries + ) + + def update_generator_registry(self): + for class_, attrs in six.iteritems(aggregated_attrs): + for expr, relationship, column in attrs: + relationships = [] + rel_class = class_ + + for path_name in relationship.split('.'): + rel = getattr(rel_class, path_name) + relationships.append(rel) + rel_class = rel.mapper.class_ + + self.generator_registry[rel_class].append( + AggregatedValue( + class_=class_, + attr=column, + relationships=list(reversed(relationships)), + expr=expr(class_) + ) + ) + + def construct_aggregate_queries(self, session, ctx): + object_dict = defaultdict(list) + for obj in session: + class_ = obj.__class__ + if class_ in self.generator_registry: + object_dict[class_].append(obj) + + for class_, objects in six.iteritems(object_dict): + for aggregate_value in self.generator_registry[class_]: + query = aggregate_value.update_query(objects) + if query is not None: + session.execute(query) + + +manager = AggregationManager() +manager.register_listeners() + + +def aggregated( + relationship, + column +): + """ + Decorator that generates an aggregated attribute. The decorated function + should return an aggregate select expression. + + :param relationship: + Defines the relationship of which the aggregate is calculated from. + The class needs to have given relationship in order to calculate the + aggregate. + :param column: + SQLAlchemy Column object. The column definition of this aggregate + attribute. + """ + def wraps(func): + return AggregatedAttribute( + func, + relationship, + column + ) + return wraps diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/asserts.py b/python-sqlalchemy-utils/sqlalchemy_utils/asserts.py new file mode 100644 index 0000000..4800c01 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/asserts.py @@ -0,0 +1,182 @@ +""" +The functions in this module can be used for testing that the constraints of +your models. Each assert function runs SQL UPDATEs that check for the existence +of given constraint. Consider the following model:: + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(200), nullable=True) + email = sa.Column(sa.String(255), nullable=False) + + + user = User(name='John Doe', email='john@example.com') + session.add(user) + session.commit() + + +We can easily test the constraints by assert_* functions:: + + + from sqlalchemy_utils import ( + assert_nullable, + assert_non_nullable, + assert_max_length + ) + + assert_nullable(user, 'name') + assert_non_nullable(user, 'email') + assert_max_length(user, 'name', 200) + + # raises AssertionError because the max length of email is 255 + assert_max_length(user, 'email', 300) +""" +from decimal import Decimal + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.exc import DataError, IntegrityError + + +def _update_field(obj, field, value): + session = sa.orm.object_session(obj) + table = sa.inspect(obj.__class__).columns[field].table + query = table.update().values(**{field: value}) + session.execute(query) + session.flush() + + +def _expect_successful_update(obj, field, value, reraise_exc): + try: + _update_field(obj, field, value) + except (reraise_exc) as e: + session = sa.orm.object_session(obj) + session.rollback() + assert False, str(e) + + +def _expect_failing_update(obj, field, value, expected_exc): + try: + _update_field(obj, field, value) + except expected_exc: + pass + else: + raise AssertionError('Expected update to raise %s' % expected_exc) + finally: + session = sa.orm.object_session(obj) + session.rollback() + + +def _repeated_value(type_): + if isinstance(type_, ARRAY): + if isinstance(type_.item_type, sa.Integer): + return [0] + elif isinstance(type_.item_type, sa.String): + return [u'a'] + elif isinstance(type_.item_type, sa.Numeric): + return [Decimal('0')] + else: + raise TypeError('Unknown array item type') + else: + return u'a' + + +def _expected_exception(type_): + if isinstance(type_, ARRAY): + return IntegrityError + else: + return DataError + + +def assert_nullable(obj, column): + """ + Assert that given column is nullable. This is checked by running an SQL + update that assigns given column as None. + + :param obj: SQLAlchemy declarative model object + :param column: Name of the column + """ + _expect_successful_update(obj, column, None, IntegrityError) + + +def assert_non_nullable(obj, column): + """ + Assert that given column is not nullable. This is checked by running an SQL + update that assigns given column as None. + + :param obj: SQLAlchemy declarative model object + :param column: Name of the column + """ + _expect_failing_update(obj, column, None, IntegrityError) + + +def assert_max_length(obj, column, max_length): + """ + Assert that the given column is of given max length. This function supports + string typed columns as well as PostgreSQL array typed columns. + + In the following example we add a check constraint that user can have a + maximum of 5 favorite colors and then test this.:: + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + favorite_colors = sa.Column(ARRAY(sa.String), nullable=False) + __table_args__ = ( + sa.CheckConstraint( + sa.func.array_length(favorite_colors, 1) <= 5 + ) + ) + + + user = User(name='John Doe', favorite_colors=['red', 'blue']) + session.add(user) + session.commit() + + + assert_max_length(user, 'favorite_colors', 5) + + + :param obj: SQLAlchemy declarative model object + :param column: Name of the column + :param max_length: Maximum length of given column + """ + type_ = sa.inspect(obj.__class__).columns[column].type + _expect_successful_update( + obj, + column, + _repeated_value(type_) * max_length, + _expected_exception(type_) + ) + _expect_failing_update( + obj, + column, + _repeated_value(type_) * (max_length + 1), + _expected_exception(type_) + ) + + +def assert_min_value(obj, column, min_value): + """ + Assert that the given column must have a minimum value of `min_value`. + + :param obj: SQLAlchemy declarative model object + :param column: Name of the column + :param min_value: The minimum allowed value for given column + """ + _expect_successful_update(obj, column, min_value, IntegrityError) + _expect_failing_update(obj, column, min_value - 1, IntegrityError) + + +def assert_max_value(obj, column, min_value): + """ + Assert that the given column must have a minimum value of `max_value`. + + :param obj: SQLAlchemy declarative model object + :param column: Name of the column + :param max_value: The maximum allowed value for given column + """ + _expect_successful_update(obj, column, min_value, IntegrityError) + _expect_failing_update(obj, column, min_value + 1, IntegrityError) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/exceptions.py b/python-sqlalchemy-utils/sqlalchemy_utils/exceptions.py new file mode 100644 index 0000000..2d84f14 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/exceptions.py @@ -0,0 +1,10 @@ +""" +Global SQLAlchemy-Utils exception classes. +""" + + +class ImproperlyConfigured(Exception): + """ + SQLAlchemy-Utils is improperly configured; normally due to usage of + a utility that depends on a missing library. + """ diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/expression_parser.py b/python-sqlalchemy-utils/sqlalchemy_utils/expression_parser.py new file mode 100644 index 0000000..074cc61 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/expression_parser.py @@ -0,0 +1,145 @@ +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + +import six +import sqlalchemy as sa +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.sql.annotation import AnnotatedColumn +from sqlalchemy.sql.elements import ( + Case, + ClauseList, + False_, + Grouping, + Label, + Null, + True_, + Tuple +) +from sqlalchemy.sql.expression import ( + BinaryExpression, + BindParameter, + BooleanClauseList, + Cast, + UnaryExpression +) + + +class ExpressionParser(object): + parsers = OrderedDict(( + (BinaryExpression, 'binary_expression'), + (BooleanClauseList, 'boolean_expression'), + (UnaryExpression, 'unary_expression'), + (sa.Column, 'column'), + (AnnotatedColumn, 'column'), + (BindParameter, 'bind_parameter'), + (False_, 'false'), + (True_, 'true'), + (Grouping, 'grouping'), + (ClauseList, 'clause_list'), + (Label, 'label'), + (Cast, 'cast'), + (Case, 'case'), + (Tuple, 'tuple'), + (Null, 'null'), + (InstrumentedAttribute, 'instrumented_attribute') + )) + + def expression(self, expr): + if expr is None: + return + for class_, parser in self.parsers.items(): + if isinstance(expr, class_): + return getattr(self, parser)(expr) + raise Exception( + 'Unknown expression type %s' % expr.__class__.__name__ + ) + + def instrumented_attribute(self, expr): + return expr + + def null(self, expr): + return expr + + def tuple(self, expr): + return expr.__class__( + *map(self.expression, expr.clauses), + type_=expr.type + ) + + def clause_list(self, expr): + return expr.__class__( + *map(self.expression, expr.clauses), + group=expr.group, + group_contents=expr.group_contents, + operator=expr.operator + ) + + def label(self, expr): + return expr.__class__( + name=expr.name, + element=self.expression(expr._element), + type_=expr.type + ) + + def cast(self, expr): + return expr.__class__( + expression=self.expression(expr.clause), + type_=expr.type + ) + + def case(self, expr): + return expr.__class__( + whens=[ + tuple(self.expression(x) for x in when) for when in expr.whens + ], + value=self.expression(expr.value), + else_=self.expression(expr.else_) + ) + + def grouping(self, expr): + return expr.__class__(self.expression(expr.element)) + + def true(self, expr): + return expr + + def false(self, expr): + return expr + + def process_table(self, table): + return table + + def column(self, column): + table = self.process_table(column.table) + return table.c[column.name] + + def unary_expression(self, expr): + return expr.operator(self.expression(expr.element)) + + def bind_parameter(self, expr): + # somehow bind parameters passed as unicode are converted to + # ascii strings along the way, force convert them back to avoid + # sqlalchemy unicode warnings + if isinstance(expr.type, sa.Unicode): + expr.value = six.text_type(expr.value) + return expr + + def binary_expression(self, expr): + return expr.__class__( + left=self.expression(expr.left), + right=self.expression(expr.right), + operator=expr.operator, + type_=expr.type, + negate=expr.negate, + modifiers=expr.modifiers.copy() + ) + + def boolean_expression(self, expr): + return expr.operator(*[ + self.expression(child_expr) + for child_expr in expr.get_children() + ]) + + def __call__(self, expr): + return self.expression(expr) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/expressions.py b/python-sqlalchemy-utils/sqlalchemy_utils/expressions.py new file mode 100644 index 0000000..1489a8f --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/expressions.py @@ -0,0 +1,143 @@ +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.expression import ( + _literal_as_text, + ClauseElement, + ColumnElement, + Executable, + FunctionElement +) +from sqlalchemy.sql.functions import GenericFunction + +from sqlalchemy_utils.functions.orm import quote + + +class explain(Executable, ClauseElement): + """ + Define EXPLAIN element. + + http://www.postgresql.org/docs/devel/static/sql-explain.html + """ + def __init__( + self, + stmt, + analyze=False, + verbose=False, + costs=True, + buffers=False, + timing=True, + format='text' + ): + self.statement = _literal_as_text(stmt) + self.analyze = analyze + self.verbose = verbose + self.costs = costs + self.buffers = buffers + self.timing = timing + self.format = format + + +class explain_analyze(explain): + def __init__(self, stmt, **kwargs): + super(explain_analyze, self).__init__( + stmt, + analyze=True, + **kwargs + ) + + +@compiles(explain, 'postgresql') +def pg_explain(element, compiler, **kw): + text = "EXPLAIN " + options = [] + if element.analyze: + options.append('ANALYZE true') + if not element.timing: + options.append('TIMING false') + if element.buffers: + options.append('BUFFERS true') + if element.format != 'text': + options.append('FORMAT %s' % element.format) + if element.verbose: + options.append('VERBOSE true') + if not element.costs: + options.append('COSTS false') + if options: + text += '(%s) ' % ', '.join(options) + text += compiler.process(element.statement) + return text + + +class array_get(FunctionElement): + name = 'array_get' + + +@compiles(array_get) +def compile_array_get(element, compiler, **kw): + args = list(element.clauses) + if len(args) != 2: + raise Exception( + "Function 'array_get' expects two arguments (%d given)." % + len(args) + ) + + if not hasattr(args[1], 'value') or not isinstance(args[1].value, int): + raise Exception( + "Second argument should be an integer." + ) + return '(%s)[%s]' % ( + compiler.process(args[0]), + sa.text(str(args[1].value + 1)) + ) + + +class row_to_json(GenericFunction): + name = 'row_to_json' + type = postgresql.JSON + + +@compiles(row_to_json, 'postgresql') +def compile_row_to_json(element, compiler, **kw): + return "%s(%s)" % (element.name, compiler.process(element.clauses)) + + +class json_array_length(GenericFunction): + name = 'json_array_length' + type = sa.Integer + + +@compiles(json_array_length, 'postgresql') +def compile_json_array_length(element, compiler, **kw): + return "%s(%s)" % (element.name, compiler.process(element.clauses)) + + +class array_agg(GenericFunction): + name = 'array_agg' + type = postgresql.ARRAY + + def __init__(self, arg, default=None, **kw): + self.type = postgresql.ARRAY(arg.type) + self.default = default + GenericFunction.__init__(self, arg, **kw) + + +@compiles(array_agg, 'postgresql') +def compile_array_agg(element, compiler, **kw): + compiled = "%s(%s)" % (element.name, compiler.process(element.clauses)) + if element.default is None: + return compiled + return str(sa.func.coalesce( + sa.text(compiled), + sa.cast(postgresql.array(element.default), element.type) + ).compile(compiler)) + + +class Asterisk(ColumnElement): + def __init__(self, selectable): + self.selectable = selectable + + +@compiles(Asterisk) +def compile_asterisk(element, compiler, **kw): + return '%s.*' % quote(compiler.dialect, element.selectable.name) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/functions/__init__.py b/python-sqlalchemy-utils/sqlalchemy_utils/functions/__init__.py new file mode 100644 index 0000000..a8409df --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/functions/__init__.py @@ -0,0 +1,46 @@ +from .database import ( # noqa + analyze, + create_database, + database_exists, + drop_database, + escape_like, + has_index, + has_unique_index, + is_auto_assigned_date_column, + json_sql +) +from .foreign_keys import ( # noqa + dependent_objects, + get_referencing_foreign_keys, + group_foreign_keys, + is_indexed_foreign_key, + merge_references, + non_indexed_foreign_keys +) +from .mock import create_mock_engine, mock_engine # noqa +from .orm import ( # noqa + get_bind, + get_class_by_table, + get_column_key, + get_columns, + get_declarative_base, + get_hybrid_properties, + get_mapper, + get_primary_keys, + get_query_entities, + get_tables, + get_type, + getdotattr, + has_changes, + identity, + is_loaded, + naturally_equivalent, + quote, + table_name +) +from .render import render_expression, render_statement # noqa +from .sort_query import ( # noqa + make_order_by_deterministic, + QuerySorterException, + sort_query +) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/functions/database.py b/python-sqlalchemy-utils/sqlalchemy_utils/functions/database.py new file mode 100644 index 0000000..040f240 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/functions/database.py @@ -0,0 +1,513 @@ +import collections +import itertools +import os +from copy import copy + +import sqlalchemy as sa +from sqlalchemy.engine.url import make_url +from sqlalchemy.exc import OperationalError, ProgrammingError + +from sqlalchemy_utils.expressions import explain_analyze + +from .orm import quote + + +class PlanAnalysis(object): + def __init__(self, plan): + self.plan = plan + + @property + def node_types(self): + types = [self.plan['Node Type']] + if 'Plans' in self.plan: + for plan in self.plan['Plans']: + analysis = PlanAnalysis(plan) + types.extend(analysis.node_types) + return types + + +class QueryAnalysis(object): + def __init__(self, result_set): + self.plan = result_set[0]['Plan'] + if 'Total Runtime' in result_set[0]: + # PostgreSQL versions < 9.4 + self.runtime = result_set[0]['Total Runtime'] + else: + # PostgreSQL versions >= 9.4 + self.runtime = ( + result_set[0]['Execution Time'] + + result_set[0]['Planning Time'] + ) + + @property + def node_types(self): + return list(PlanAnalysis(self.plan).node_types) + + def __repr__(self): + return '' % self.runtime + + +def analyze(conn, query): + """ + Analyze query using given connection and return :class:`QueryAnalysis` + object. Analysis is performed using database specific EXPLAIN ANALYZE + construct and then examining the results into structured format. Currently + only PostgreSQL is supported. + + + Getting query runtime (in database level) :: + + + from sqlalchemy_utils import analyze + + + analysis = analyze(conn, 'SELECT * FROM article') + analysis.runtime # runtime as milliseconds + + + Analyze can be very useful when testing that query doesn't issue a + sequential scan (scanning all rows in table). You can for example write + simple performance tests this way.:: + + + query = ( + session.query(Article.name) + .order_by(Article.name) + .limit(10) + ) + analysis = analyze(self.connection, query) + analysis.node_types # [u'Limit', u'Index Only Scan'] + + assert 'Seq Scan' not in analysis.node_types + + + .. versionadded: 0.26.17 + + :param conn: SQLAlchemy Connection object + :param query: SQLAlchemy Query object or query as a string + """ + return QueryAnalysis( + conn.execute( + explain_analyze(query, buffers=True, format='json') + ).scalar() + ) + + +def escape_like(string, escape_char='*'): + """ + Escape the string paremeter used in SQL LIKE expressions. + + :: + + from sqlalchemy_utils import escape_like + + + query = session.query(User).filter( + User.name.ilike(escape_like('John')) + ) + + + :param string: a string to escape + :param escape_char: escape character + """ + return ( + string + .replace(escape_char, escape_char * 2) + .replace('%', escape_char + '%') + .replace('_', escape_char + '_') + ) + + +def json_sql(value, scalars_to_json=True): + """ + Convert python data structures to PostgreSQL specific SQLAlchemy JSON + constructs. This function is extremly useful if you need to build + PostgreSQL JSON on python side. + + .. note:: + + This function needs PostgreSQL >= 9.4 + + Scalars are converted to to_json SQLAlchemy function objects + + :: + + json_sql(1) # Equals SQL: to_json(1) + + json_sql('a') # to_json('a') + + + Mappings are converted to json_build_object constructs + + :: + + json_sql({'a': 'c', '2': 5}) # json_build_object('a', 'c', '2', 5) + + + Sequences (other than strings) are converted to json_build_array constructs + + :: + + json_sql([1, 2, 3]) # json_build_array(1, 2, 3) + + + You can also nest these data structures + + :: + + json_sql({'a': [1, 2, 3]}) + # json_build_object('a', json_build_array[1, 2, 3]) + + + :param value: + value to be converted to SQLAlchemy PostgreSQL function constructs + """ + scalar_convert = sa.text + if scalars_to_json: + scalar_convert = lambda a: sa.func.to_json(sa.text(a)) + + if isinstance(value, collections.Mapping): + return sa.func.json_build_object( + *( + json_sql(v, scalars_to_json=False) + for v in itertools.chain(*value.items()) + ) + ) + elif isinstance(value, str): + return scalar_convert("'{0}'".format(value)) + elif isinstance(value, collections.Sequence): + return sa.func.json_build_array( + *( + json_sql(v, scalars_to_json=False) + for v in value + ) + ) + elif isinstance(value, (int, float)): + return scalar_convert(str(value)) + return value + + +def has_index(column): + """ + Return whether or not given column has an index. A column has an index if + it has a single column index or it is the first column in compound column + index. + + :param column: SQLAlchemy Column object + + .. versionadded: 0.26.2 + + :: + + from sqlalchemy_utils import has_index + + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String(100)) + is_published = sa.Column(sa.Boolean, index=True) + is_deleted = sa.Column(sa.Boolean) + is_archived = sa.Column(sa.Boolean) + + __table_args__ = ( + sa.Index('my_index', is_deleted, is_archived), + ) + + + table = Article.__table__ + + has_index(table.c.is_published) # True + has_index(table.c.is_deleted) # True + has_index(table.c.is_archived) # False + + + Also supports primary key indexes + + :: + + from sqlalchemy_utils import has_index + + + class ArticleTranslation(Base): + __tablename__ = 'article_translation' + id = sa.Column(sa.Integer, primary_key=True) + locale = sa.Column(sa.String(10), primary_key=True) + title = sa.Column(sa.String(100)) + + + table = ArticleTranslation.__table__ + + has_index(table.c.locale) # False + has_index(table.c.id) # True + """ + table = column.table + if not isinstance(table, sa.Table): + raise TypeError( + 'Only columns belonging to Table objects are supported. Given ' + 'column belongs to %r.' % table + ) + primary_keys = table.primary_key.columns.values() + return ( + (primary_keys and column is primary_keys[0]) + or + any( + index.columns.values()[0] is column + for index in table.indexes + ) + ) + + +def has_unique_index(column): + """ + Return whether or not given column has a unique index. A column has a + unique index if it has a single column primary key index or it has a + single column UniqueConstraint. + + :param column: SQLAlchemy Column object + + .. versionadded: 0.27.1 + + :: + + from sqlalchemy_utils import has_unique_index + + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String(100)) + is_published = sa.Column(sa.Boolean, unique=True) + is_deleted = sa.Column(sa.Boolean) + is_archived = sa.Column(sa.Boolean) + + + table = Article.__table__ + + has_unique_index(table.c.is_published) # True + has_unique_index(table.c.is_deleted) # False + has_unique_index(table.c.id) # True + + + :raises TypeError: if given column does not belong to a Table object + """ + table = column.table + if not isinstance(table, sa.Table): + raise TypeError( + 'Only columns belonging to Table objects are supported. Given ' + 'column belongs to %r.' % table + ) + pks = table.primary_key.columns + return ( + (column is pks.values()[0] and len(pks) == 1) + or + any( + match_columns(constraint.columns.values()[0], column) and + len(constraint.columns) == 1 + for constraint in column.table.constraints + if isinstance(constraint, sa.sql.schema.UniqueConstraint) + ) + ) + + +def match_columns(column, column2): + return column.table is column2.table and column.name == column2.name + + +def is_auto_assigned_date_column(column): + """ + Returns whether or not given SQLAlchemy Column object's is auto assigned + DateTime or Date. + + :param column: SQLAlchemy Column object + """ + return ( + ( + isinstance(column.type, sa.DateTime) or + isinstance(column.type, sa.Date) + ) + and + ( + column.default or + column.server_default or + column.onupdate or + column.server_onupdate + ) + ) + + +def database_exists(url): + """Check if a database exists. + + :param url: A SQLAlchemy engine URL. + + Performs backend-specific testing to quickly determine if a database + exists on the server. :: + + database_exists('postgres://postgres@localhost/name') #=> False + create_database('postgres://postgres@localhost/name') + database_exists('postgres://postgres@localhost/name') #=> True + + Supports checking against a constructed URL as well. :: + + engine = create_engine('postgres://postgres@localhost/name') + database_exists(engine.url) #=> False + create_database(engine.url) + database_exists(engine.url) #=> True + + """ + + url = copy(make_url(url)) + database = url.database + if url.drivername.startswith('postgresql'): + url.database = 'template1' + else: + url.database = None + + engine = sa.create_engine(url) + + if engine.dialect.name == 'postgresql': + text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database + return bool(engine.execute(text).scalar()) + + elif engine.dialect.name == 'mysql': + text = ("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA " + "WHERE SCHEMA_NAME = '%s'" % database) + return bool(engine.execute(text).scalar()) + + elif engine.dialect.name == 'sqlite': + return database == ':memory:' or os.path.exists(database) + + else: + text = 'SELECT 1' + try: + url.database = database + engine = sa.create_engine(url) + engine.execute(text) + return True + + except (ProgrammingError, OperationalError): + return False + + +def create_database(url, encoding='utf8', template=None): + """Issue the appropriate CREATE DATABASE statement. + + :param url: A SQLAlchemy engine URL. + :param encoding: The encoding to create the database as. + :param template: + The name of the template from which to create the new database. At the + moment only supported by PostgreSQL driver. + + To create a database, you can pass a simple URL that would have + been passed to ``create_engine``. :: + + create_database('postgres://postgres@localhost/name') + + You may also pass the url from an existing engine. :: + + create_database(engine.url) + + Has full support for mysql, postgres, and sqlite. In theory, + other database engines should be supported. + """ + + url = copy(make_url(url)) + + database = url.database + + if url.drivername.startswith('postgresql'): + url.database = 'template1' + elif not url.drivername.startswith('sqlite'): + url.database = None + + engine = sa.create_engine(url) + + if engine.dialect.name == 'postgresql': + if engine.driver == 'psycopg2': + from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT + engine.raw_connection().set_isolation_level( + ISOLATION_LEVEL_AUTOCOMMIT + ) + + if not template: + template = 'template0' + + text = "CREATE DATABASE {0} ENCODING '{1}' TEMPLATE {2}".format( + quote(engine, database), + encoding, + quote(engine, template) + ) + engine.execute(text) + + elif engine.dialect.name == 'mysql': + text = "CREATE DATABASE {0} CHARACTER SET = '{1}'".format( + quote(engine, database), + encoding + ) + engine.execute(text) + + elif engine.dialect.name == 'sqlite' and database != ':memory:': + open(database, 'w').close() + + else: + text = 'CREATE DATABASE {0}'.format(quote(engine, database)) + engine.execute(text) + + +def drop_database(url): + """Issue the appropriate DROP DATABASE statement. + + :param url: A SQLAlchemy engine URL. + + Works similar to the :ref:`create_database` method in that both url text + and a constructed url are accepted. :: + + drop_database('postgres://postgres@localhost/name') + drop_database(engine.url) + + """ + + url = copy(make_url(url)) + + database = url.database + + if url.drivername.startswith('postgresql'): + url.database = 'template1' + elif not url.drivername.startswith('sqlite'): + url.database = None + + engine = sa.create_engine(url) + + if engine.dialect.name == 'sqlite' and url.database != ':memory:': + os.remove(url.database) + + elif engine.dialect.name == 'postgresql' and engine.driver == 'psycopg2': + from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT + engine.raw_connection().set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) + + # Disconnect all users from the database we are dropping. + version = list( + map( + int, + engine.execute('SHOW server_version').first()[0].split('.') + ) + ) + pid_column = ( + 'pid' if (version[0] >= 9 and version[1] >= 2) else 'procpid' + ) + text = ''' + SELECT pg_terminate_backend(pg_stat_activity.%(pid_column)s) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '%(database)s' + AND %(pid_column)s <> pg_backend_pid(); + ''' % {'pid_column': pid_column, 'database': database} + engine.execute(text) + + # Drop the database. + text = 'DROP DATABASE {0}'.format(quote(engine, database)) + engine.execute(text) + + else: + text = 'DROP DATABASE {0}'.format(quote(engine, database)) + engine.execute(text) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/functions/foreign_keys.py b/python-sqlalchemy-utils/sqlalchemy_utils/functions/foreign_keys.py new file mode 100644 index 0000000..81cd782 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/functions/foreign_keys.py @@ -0,0 +1,365 @@ +from collections import defaultdict +from itertools import groupby + +import six +import sqlalchemy as sa +from sqlalchemy.exc import NoInspectionAvailable +from sqlalchemy.orm import object_session +from sqlalchemy.schema import ForeignKeyConstraint, MetaData, Table + +from ..query_chain import QueryChain +from .orm import get_column_key, get_mapper, get_tables + + +def get_foreign_key_values(fk, obj): + return dict( + ( + fk.constraint.columns.values()[index].key, + getattr(obj, element.column.key) + ) + for + index, element + in + enumerate(fk.constraint.elements) + ) + + +def group_foreign_keys(foreign_keys): + """ + Return a groupby iterator that groups given foreign keys by table. + + :param foreign_keys: a sequence of foreign keys + + + :: + + foreign_keys = get_referencing_foreign_keys(User) + + for table, fks in group_foreign_keys(foreign_keys): + # do something + pass + + + .. seealso:: :func:`get_referencing_foreign_keys` + + .. versionadded: 0.26.1 + """ + foreign_keys = sorted( + foreign_keys, key=lambda key: key.constraint.table.name + ) + return groupby(foreign_keys, lambda key: key.constraint.table) + + +def get_referencing_foreign_keys(mixed): + """ + Returns referencing foreign keys for given Table object or declarative + class. + + :param mixed: + SA Table object or SA declarative class + + :: + + get_referencing_foreign_keys(User) # set([ForeignKey('user.id')]) + + get_referencing_foreign_keys(User.__table__) + + + This function also understands inheritance. This means it returns + all foreign keys that reference any table in the class inheritance tree. + + Let's say you have three classes which use joined table inheritance, + namely TextItem, Article and BlogPost with Article and BlogPost inheriting + TextItem. + + :: + + # This will check all foreign keys that reference either article table + # or textitem table. + get_referencing_foreign_keys(Article) + + .. seealso:: :func:`get_tables` + """ + if isinstance(mixed, sa.Table): + tables = [mixed] + else: + tables = get_tables(mixed) + + referencing_foreign_keys = set() + + for table in mixed.metadata.tables.values(): + if table not in tables: + for constraint in table.constraints: + if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint): + for fk in constraint.elements: + if any(fk.references(t) for t in tables): + referencing_foreign_keys.add(fk) + return referencing_foreign_keys + + +def merge_references(from_, to, foreign_keys=None): + """ + Merge the references of an entity into another entity. + + Consider the following models:: + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(255)) + + def __repr__(self): + return 'User(name=%r)' % self.name + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String(255)) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + + author = sa.orm.relationship(User) + + + Now lets add some data:: + + john = self.User(name='John') + jack = self.User(name='Jack') + post = self.BlogPost(title='Some title', author=john) + post2 = self.BlogPost(title='Other title', author=jack) + self.session.add_all([ + john, + jack, + post, + post2 + ]) + self.session.commit() + + + If we wanted to merge all John's references to Jack it would be as easy as + :: + + merge_references(john, jack) + self.session.commit() + + post.author # User(name='Jack') + post2.author # User(name='Jack') + + + + :param from_: an entity to merge into another entity + :param to: an entity to merge another entity into + :param foreign_keys: A sequence of foreign keys. By default this is None + indicating all referencing foreign keys should be used. + + .. seealso: :func:`dependent_objects` + + .. versionadded: 0.26.1 + """ + if from_.__tablename__ != to.__tablename__: + raise TypeError('The tables of given arguments do not match.') + + session = object_session(from_) + foreign_keys = get_referencing_foreign_keys(from_) + + for fk in foreign_keys: + old_values = get_foreign_key_values(fk, from_) + new_values = get_foreign_key_values(fk, to) + criteria = ( + getattr(fk.constraint.table.c, key) == value + for key, value in six.iteritems(old_values) + ) + try: + mapper = get_mapper(fk.constraint.table) + except ValueError: + query = ( + fk.constraint.table + .update() + .where(sa.and_(*criteria)) + .values(new_values) + ) + session.execute(query) + else: + ( + session.query(mapper.class_) + .filter_by(**old_values) + .update( + new_values, + 'evaluate' + ) + ) + + +def dependent_objects(obj, foreign_keys=None): + """ + Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates + through all dependent objects for given SQLAlchemy object. + + Consider a User object is referenced in various articles and also in + various orders. Getting all these dependent objects is as easy as:: + + from sqlalchemy_utils import dependent_objects + + + dependent_objects(user) + + + If you expect an object to have lots of dependent_objects it might be good + to limit the results:: + + + dependent_objects(user).limit(5) + + + + The common use case is checking for all restrict dependent objects before + deleting parent object and inform the user if there are dependent objects + with ondelete='RESTRICT' foreign keys. If this kind of checking is not used + it will lead to nasty IntegrityErrors being raised. + + In the following example we delete given user if it doesn't have any + foreign key restricted dependent objects:: + + + from sqlalchemy_utils import get_referencing_foreign_keys + + + user = session.query(User).get(some_user_id) + + + deps = list( + dependent_objects( + user, + ( + fk for fk in get_referencing_foreign_keys(User) + # On most databases RESTRICT is the default mode hence we + # check for None values also + if fk.ondelete == 'RESTRICT' or fk.ondelete is None + ) + ).limit(5) + ) + + if deps: + # Do something to inform the user + pass + else: + session.delete(user) + + + :param obj: SQLAlchemy declarative model object + :param foreign_keys: + A sequence of foreign keys to use for searching the dependent_objects + for given object. By default this is None, indicating that all foreign + keys referencing the object will be used. + + .. note:: + This function does not support exotic mappers that use multiple tables + + .. seealso:: :func:`get_referencing_foreign_keys` + .. seealso:: :func:`merge_references` + + .. versionadded: 0.26.0 + """ + if foreign_keys is None: + foreign_keys = get_referencing_foreign_keys(obj) + + session = object_session(obj) + + chain = QueryChain([]) + classes = obj.__class__._decl_class_registry + + for table, keys in group_foreign_keys(foreign_keys): + keys = list(keys) + for class_ in classes.values(): + try: + mapper = sa.inspect(class_) + except NoInspectionAvailable: + continue + parent_mapper = mapper.inherits + if ( + table in mapper.tables and + not (parent_mapper and table in parent_mapper.tables) + ): + query = session.query(class_).filter( + sa.or_(*_get_criteria(keys, class_, obj)) + ) + chain.queries.append(query) + return chain + + +def _get_criteria(keys, class_, obj): + criteria = [] + visited_constraints = [] + for key in keys: + if key.constraint in visited_constraints: + continue + visited_constraints.append(key.constraint) + + subcriteria = [] + for index, column in enumerate(key.constraint.columns): + foreign_column = ( + key.constraint.elements[index].column + ) + subcriteria.append( + getattr(class_, get_column_key(class_, column)) == + getattr( + obj, + sa.inspect(type(obj)) + .get_property_by_column( + foreign_column + ).key + ) + ) + criteria.append(sa.and_(*subcriteria)) + return criteria + + +def non_indexed_foreign_keys(metadata, engine=None): + """ + Finds all non indexed foreign keys from all tables of given MetaData. + + Very useful for optimizing postgresql database and finding out which + foreign keys need indexes. + + :param metadata: MetaData object to inspect tables from + """ + reflected_metadata = MetaData() + + if metadata.bind is None and engine is None: + raise Exception( + 'Either pass a metadata object with bind or ' + 'pass engine as a second parameter' + ) + + constraints = defaultdict(list) + + for table_name in metadata.tables.keys(): + table = Table( + table_name, + reflected_metadata, + autoload=True, + autoload_with=metadata.bind or engine + ) + + for constraint in table.constraints: + if not isinstance(constraint, ForeignKeyConstraint): + continue + + if not is_indexed_foreign_key(constraint): + constraints[table.name].append(constraint) + + return dict(constraints) + + +def is_indexed_foreign_key(constraint): + """ + Whether or not given foreign key constraint's columns have been indexed. + + :param constraint: ForeignKeyConstraint object to check the indexes + """ + return any( + set(constraint.columns.keys()) + == + set(column.name for column in index.columns) + for index + in constraint.table.indexes + ) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/functions/mock.py b/python-sqlalchemy-utils/sqlalchemy_utils/functions/mock.py new file mode 100644 index 0000000..258e4da --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/functions/mock.py @@ -0,0 +1,109 @@ +import contextlib +import datetime +import inspect +import re + +import six +import sqlalchemy as sa + + +def create_mock_engine(bind, stream=None): + """Create a mock SQLAlchemy engine from the passed engine or bind URL. + + :param bind: A SQLAlchemy engine or bind URL to mock. + :param stream: Render all DDL operations to the stream. + """ + + if not isinstance(bind, six.string_types): + bind_url = str(bind.url) + + else: + bind_url = bind + + if stream is not None: + + def dump(sql, *args, **kwargs): + + class Compiler(type(sql._compiler(engine.dialect))): + + def visit_bindparam(self, bindparam, *args, **kwargs): + return self.render_literal_value( + bindparam.value, bindparam.type) + + def render_literal_value(self, value, type_): + if isinstance(value, six.integer_types): + return str(value) + + elif isinstance(value, (datetime.date, datetime.datetime)): + return "'%s'" % value + + return super(Compiler, self).render_literal_value( + value, type_) + + text = str(Compiler(engine.dialect, sql).process(sql)) + text = re.sub(r'\n+', '\n', text) + text = text.strip('\n').strip() + + stream.write('\n%s;' % text) + + else: + + dump = lambda *a, **kw: None + + engine = sa.create_engine(bind_url, strategy='mock', executor=dump) + return engine + + +@contextlib.contextmanager +def mock_engine(engine, stream=None): + """Mocks out the engine specified in the passed bind expression. + + Note this function is meant for convenience and protected usage. Do NOT + blindly pass user input to this function as it uses exec. + + :param engine: A python expression that represents the engine to mock. + :param stream: Render all DDL operations to the stream. + """ + + # Create a stream if not present. + + if stream is None: + stream = six.moves.cStringIO() + + # Navigate the stack and find the calling frame that allows the + # expression to execuate. + + for frame in inspect.stack()[1:]: + + try: + frame = frame[0] + expression = '__target = %s' % engine + six.exec_(expression, frame.f_globals, frame.f_locals) + target = frame.f_locals['__target'] + break + + except: + pass + + else: + + raise ValueError('Not a valid python expression', engine) + + # Evaluate the expression and get the target engine. + + frame.f_locals['__mock'] = create_mock_engine(target, stream) + + # Replace the target with our mock. + + six.exec_('%s = __mock' % engine, frame.f_globals, frame.f_locals) + + # Give control back. + + yield stream + + # Put the target engine back. + + frame.f_locals['__target'] = target + six.exec_('%s = __target' % engine, frame.f_globals, frame.f_locals) + six.exec_('del __target', frame.f_globals, frame.f_locals) + six.exec_('del __mock', frame.f_globals, frame.f_locals) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/functions/orm.py b/python-sqlalchemy-utils/sqlalchemy_utils/functions/orm.py new file mode 100644 index 0000000..d90c272 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/functions/orm.py @@ -0,0 +1,930 @@ +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + +from functools import partial +from inspect import isclass +from operator import attrgetter + +import six +import sqlalchemy as sa +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import mapperlib +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.orm.exc import UnmappedInstanceError +from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty +from sqlalchemy.orm.query import _ColumnEntity +from sqlalchemy.orm.session import object_session +from sqlalchemy.orm.util import AliasedInsp + +from sqlalchemy_utils.utils import is_sequence + + +def get_class_by_table(base, table, data=None): + """ + Return declarative class associated with given table. If no class is found + this function returns `None`. If multiple classes were found (polymorphic + cases) additional `data` parameter can be given to hint which class + to return. + + :: + + class User(Base): + __tablename__ = 'entity' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + + + get_class_by_table(Base, User.__table__) # User class + + + This function also supports models using single table inheritance. + Additional data paratemer should be provided in these case. + + :: + + class Entity(Base): + __tablename__ = 'entity' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + type = sa.Column(sa.String) + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'entity' + } + + class User(Entity): + __mapper_args__ = { + 'polymorphic_identity': 'user' + } + + + # Entity class + get_class_by_table(Base, Entity.__table__, {'type': 'entity'}) + + # User class + get_class_by_table(Base, Entity.__table__, {'type': 'user'}) + + + :param base: Declarative model base + :param table: SQLAlchemy Table object + :param data: Data row to determine the class in polymorphic scenarios + :return: Declarative class or None. + """ + found_classes = set( + c for c in base._decl_class_registry.values() + if hasattr(c, '__table__') and c.__table__ is table + ) + if len(found_classes) > 1: + if not data: + raise ValueError( + "Multiple declarative classes found for table '{0}'. " + "Please provide data parameter for this function to be able " + "to determine polymorphic scenarios.".format( + table.name + ) + ) + else: + for cls in found_classes: + mapper = sa.inspect(cls) + polymorphic_on = mapper.polymorphic_on.name + if polymorphic_on in data: + if data[polymorphic_on] == mapper.polymorphic_identity: + return cls + raise ValueError( + "Multiple declarative classes found for table '{0}'. Given " + "data row does not match any polymorphic identity of the " + "found classes.".format( + table.name + ) + ) + elif found_classes: + return found_classes.pop() + return None + + +def get_type(expr): + """ + Return the associated type with given Column, InstrumentedAttribute, + ColumnProperty, RelationshipProperty or other similar SQLAlchemy construct. + + For constructs wrapping columns this is the column type. For relationships + this function returns the relationship mapper class. + + :param expr: + SQLAlchemy Column, InstrumentedAttribute, ColumnProperty or other + similar SA construct. + + :: + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) + author = sa.orm.relationship(User) + + + get_type(User.__table__.c.name) # sa.String() + get_type(User.name) # sa.String() + get_type(User.name.property) # sa.String() + + get_type(Article.author) # User + + + .. versionadded: 0.30.9 + """ + if hasattr(expr, 'type'): + return expr.type + elif isinstance(expr, InstrumentedAttribute): + expr = expr.property + + if isinstance(expr, ColumnProperty): + return expr.columns[0].type + elif isinstance(expr, RelationshipProperty): + return expr.mapper.class_ + raise TypeError("Couldn't inspect type.") + + +def get_column_key(model, column): + """ + Return the key for given column in given model. + + :param model: SQLAlchemy declarative model object + + :: + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column('_name', sa.String) + + + get_column_key(User, User.__table__.c._name) # 'name' + + .. versionadded: 0.26.5 + + .. versionchanged: 0.27.11 + Throws UnmappedColumnError instead of ValueError when no property was + found for given column. This is consistent with how SQLAlchemy works. + """ + mapper = sa.inspect(model) + try: + return mapper.get_property_by_column(column).key + except sa.orm.exc.UnmappedColumnError: + for key, c in mapper.columns.items(): + if c.name == column.name and c.table is column.table: + return key + raise sa.orm.exc.UnmappedColumnError( + 'No column %s is configured on mapper %s...' % + (column, mapper) + ) + + +def get_mapper(mixed): + """ + Return related SQLAlchemy Mapper for given SQLAlchemy object. + + :param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object + + :: + + from sqlalchemy_utils import get_mapper + + + get_mapper(User) + + get_mapper(User()) + + get_mapper(User.__table__) + + get_mapper(User.__mapper__) + + get_mapper(sa.orm.aliased(User)) + + get_mapper(sa.orm.aliased(User.__table__)) + + + Raises: + ValueError: if multiple mappers were found for given argument + + .. versionadded: 0.26.1 + """ + if isinstance(mixed, sa.orm.query._MapperEntity): + mixed = mixed.expr + elif isinstance(mixed, sa.Column): + mixed = mixed.table + elif isinstance(mixed, sa.orm.query._ColumnEntity): + mixed = mixed.expr + + if isinstance(mixed, sa.orm.Mapper): + return mixed + if isinstance(mixed, sa.orm.util.AliasedClass): + return sa.inspect(mixed).mapper + if isinstance(mixed, sa.sql.selectable.Alias): + mixed = mixed.element + if isinstance(mixed, AliasedInsp): + return mixed.mapper + if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): + mixed = mixed.class_ + if isinstance(mixed, sa.Table): + mappers = [ + mapper for mapper in mapperlib._mapper_registry + if mixed in mapper.tables + ] + if len(mappers) > 1: + raise ValueError( + "Multiple mappers found for table '%s'." % mixed.name + ) + elif not mappers: + raise ValueError( + "Could not get mapper for table '%s'." % mixed.name + ) + else: + return mappers[0] + if not isclass(mixed): + mixed = type(mixed) + return sa.inspect(mixed) + + +def get_bind(obj): + """ + Return the bind for given SQLAlchemy Engine / Connection / declarative + model object. + + :param obj: SQLAlchemy Engine / Connection / declarative model object + + :: + + from sqlalchemy_utils import get_bind + + + get_bind(session) # Connection object + + get_bind(user) + + """ + if hasattr(obj, 'bind'): + conn = obj.bind + else: + try: + conn = object_session(obj).bind + except UnmappedInstanceError: + conn = obj + + if not hasattr(conn, 'execute'): + raise TypeError( + 'This method accepts only Session, Engine, Connection and ' + 'declarative model objects.' + ) + return conn + + +def get_primary_keys(mixed): + """ + Return an OrderedDict of all primary keys for given Table object, + declarative class or declarative class instance. + + :param mixed: + SA Table object, SA declarative class or SA declarative class instance + + :: + + get_primary_keys(User) + + get_primary_keys(User()) + + get_primary_keys(User.__table__) + + get_primary_keys(User.__mapper__) + + get_primary_keys(sa.orm.aliased(User)) + + get_primary_keys(sa.orm.aliased(User.__table__)) + + + .. versionchanged: 0.25.3 + Made the function return an ordered dictionary instead of generator. + This change was made to support primary key aliases. + + Renamed this function to 'get_primary_keys', formerly 'primary_keys' + + .. seealso:: :func:`get_columns` + """ + return OrderedDict( + ( + (key, column) for key, column in get_columns(mixed).items() + if column.primary_key + ) + ) + + +def get_tables(mixed): + """ + Return a set of tables associated with given SQLAlchemy object. + + Let's say we have three classes which use joined table inheritance + TextItem, Article and BlogPost. Article and BlogPost inherit TextItem. + + :: + + get_tables(Article) # set([Table('article', ...), Table('text_item')]) + + get_tables(Article()) + + get_tables(Article.__mapper__) + + + If the TextItem entity is using with_polymorphic='*' then this function + returns all child tables (article and blog_post) as well. + + :: + + + get_tables(TextItem) # set([Table('text_item', ...)], ...]) + + + .. versionadded: 0.26.0 + + :param mixed: + SQLAlchemy Mapper, Declarative class, Column, InstrumentedAttribute or + a SA Alias object wrapping any of these objects. + """ + if isinstance(mixed, sa.Table): + return [mixed] + elif isinstance(mixed, sa.Column): + return [mixed.table] + elif isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): + return mixed.parent.tables + elif isinstance(mixed, sa.orm.query._ColumnEntity): + mixed = mixed.expr + + mapper = get_mapper(mixed) + + polymorphic_mappers = get_polymorphic_mappers(mapper) + if polymorphic_mappers: + tables = sum((m.tables for m in polymorphic_mappers), []) + else: + tables = mapper.tables + return tables + + +def get_columns(mixed): + """ + Return a collection of all Column objects for given SQLAlchemy + object. + + The type of the collection depends on the type of the object to return the + columns from. + + :: + + get_columns(User) + + get_columns(User()) + + get_columns(User.__table__) + + get_columns(User.__mapper__) + + get_columns(sa.orm.aliased(User)) + + get_columns(sa.orm.alised(User.__table__)) + + + :param mixed: + SA Table object, SA Mapper, SA declarative class, SA declarative class + instance or an alias of any of these objects + """ + if isinstance(mixed, sa.Table): + return mixed.c + if isinstance(mixed, sa.orm.util.AliasedClass): + return sa.inspect(mixed).mapper.columns + if isinstance(mixed, sa.sql.selectable.Alias): + return mixed.c + if isinstance(mixed, sa.orm.Mapper): + return mixed.columns + if not isclass(mixed): + mixed = mixed.__class__ + return sa.inspect(mixed).columns + + +def table_name(obj): + """ + Return table name of given target, declarative class or the + table name where the declarative attribute is bound to. + """ + class_ = getattr(obj, 'class_', obj) + + try: + return class_.__tablename__ + except AttributeError: + pass + + try: + return class_.__table__.name + except AttributeError: + pass + + +def getattrs(obj, attrs): + return map(partial(getattr, obj), attrs) + + +def quote(mixed, ident): + """ + Conditionally quote an identifier. + :: + + + from sqlalchemy_utils import quote + + + engine = create_engine('sqlite:///:memory:') + + quote(engine, 'order') + # '"order"' + + quote(engine, 'some_other_identifier') + # 'some_other_identifier' + + + :param mixed: SQLAlchemy Session / Connection / Engine / Dialect object. + :param ident: identifier to conditionally quote + """ + if isinstance(mixed, Dialect): + dialect = mixed + else: + dialect = get_bind(mixed).dialect + return dialect.preparer(dialect).quote(ident) + + +def query_labels(query): + """ + Return all labels for given SQLAlchemy query object. + + Example:: + + + query = session.query( + Category, + db.func.count(Article.id).label('articles') + ) + + query_labels(query) # ['articles'] + + :param query: SQLAlchemy Query object + """ + return [ + entity._label_name for entity in query._entities + if isinstance(entity, _ColumnEntity) and entity._label_name + ] + + +def get_query_entities(query): + """ + Return a list of all entities present in given SQLAlchemy query object. + + Examples:: + + + from sqlalchemy_utils import get_query_entities + + + query = session.query(Category) + + get_query_entities(query) # [] + + + query = session.query(Category.id) + + get_query_entities(query) # [] + + + This function also supports queries with joins. + + :: + + + query = session.query(Category).join(Article) + + get_query_entities(query) # [,
] + + .. versionchanged: 0.26.7 + This function now returns a list instead of generator + + :param query: SQLAlchemy Query object + """ + exprs = [ + d['expr'] + if is_labeled_query(d['expr']) or isinstance(d['expr'], sa.Column) + else d['entity'] + for d in query.column_descriptions + ] + return [ + get_query_entity(expr) for expr in exprs + ] + [ + get_query_entity(entity) for entity in query._join_entities + ] + + +def is_labeled_query(expr): + return ( + isinstance(expr, sa.sql.elements.Label) and + isinstance( + list(expr.base_columns)[0], + (sa.sql.selectable.Select, sa.sql.selectable.ScalarSelect) + ) + ) + + +def get_query_entity(expr): + if isinstance(expr, sa.orm.attributes.InstrumentedAttribute): + return expr.parent.class_ + elif isinstance(expr, sa.Column): + return expr.table + elif isinstance(expr, AliasedInsp): + return expr.entity + return expr + + +def get_query_entity_by_alias(query, alias): + entities = get_query_entities(query) + + if not alias: + return entities[0] + + for entity in entities: + if isinstance(entity, sa.orm.util.AliasedClass): + name = sa.inspect(entity).name + else: + name = get_mapper(entity).tables[0].name + + if name == alias: + return entity + + +def get_polymorphic_mappers(mixed): + if isinstance(mixed, AliasedInsp): + return mixed.with_polymorphic_mappers + else: + return mixed.polymorphic_map.values() + + +def get_query_descriptor(query, entity, attr): + if attr in query_labels(query): + return attr + else: + entity = get_query_entity_by_alias(query, entity) + if entity: + descriptor = get_descriptor(entity, attr) + if ( + hasattr(descriptor, 'property') and + isinstance(descriptor.property, sa.orm.RelationshipProperty) + ): + return + return descriptor + + +def get_descriptor(entity, attr): + mapper = sa.inspect(entity) + + for key, descriptor in get_all_descriptors(mapper).items(): + if attr == key: + prop = ( + descriptor.property + if hasattr(descriptor, 'property') + else None + ) + if isinstance(prop, ColumnProperty): + if isinstance(entity, sa.orm.util.AliasedClass): + for c in mapper.selectable.c: + if c.key == attr: + return c + else: + # If the property belongs to a class that uses + # polymorphic inheritance we have to take into account + # situations where the attribute exists in child class + # but not in parent class. + return getattr(prop.parent.class_, attr) + else: + # Handle synonyms, relationship properties and hybrid + # properties + try: + return getattr(mapper.class_, attr) + except AttributeError: + pass + + +def get_all_descriptors(expr): + insp = sa.inspect(expr) + polymorphic_mappers = get_polymorphic_mappers(insp) + if polymorphic_mappers: + attrs = dict(get_mapper(expr).all_orm_descriptors) + for submapper in polymorphic_mappers: + for key, descriptor in submapper.all_orm_descriptors.items(): + if key not in attrs: + attrs[key] = descriptor + return attrs + return get_mapper(expr).all_orm_descriptors + + +def get_hybrid_properties(model): + """ + Returns a dictionary of hybrid property keys and hybrid properties for + given SQLAlchemy declarative model / mapper. + + + Consider the following model + + :: + + + from sqlalchemy.ext.hybrid import hybrid_property + + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @hybrid_property + def lowercase_name(self): + return self.name.lower() + + @lowercase_name.expression + def lowercase_name(cls): + return sa.func.lower(cls.name) + + + You can now easily get a list of all hybrid property names + + :: + + + from sqlalchemy_utils import get_hybrid_properties + + + get_hybrid_properties(Category).keys() # ['lowercase_name'] + + + .. versionchanged: 0.26.7 + This function now returns a dictionary instead of generator + + :param model: SQLAlchemy declarative model or mapper + """ + return dict( + (key, prop) + for key, prop in sa.inspect(model).all_orm_descriptors.items() + if isinstance(prop, hybrid_property) + ) + + +def get_declarative_base(model): + """ + Returns the declarative base for given model class. + + :param model: SQLAlchemy declarative model + """ + for parent in model.__bases__: + try: + parent.metadata + return get_declarative_base(parent) + except AttributeError: + pass + return model + + +def getdotattr(obj_or_class, dot_path, condition=None): + """ + Allow dot-notated strings to be passed to `getattr`. + + :: + + getdotattr(SubSection, 'section.document') + + getdotattr(subsection, 'section.document') + + + :param obj_or_class: Any object or class + :param dot_path: Attribute path with dot mark as separator + """ + last = obj_or_class + + for path in str(dot_path).split('.'): + getter = attrgetter(path) + + if is_sequence(last): + tmp = [] + for element in last: + value = getter(element) + if is_sequence(value): + tmp.extend(value) + else: + tmp.append(value) + last = tmp + elif isinstance(last, InstrumentedAttribute): + last = getter(last.property.mapper.class_) + elif last is None: + return None + else: + last = getter(last) + if condition is not None: + if is_sequence(last): + last = [v for v in last if condition(v)] + else: + if not condition(last): + return None + + return last + + +def is_deleted(obj): + return obj in sa.orm.object_session(obj).deleted + + +def has_changes(obj, attrs=None, exclude=None): + """ + Simple shortcut function for checking if given attributes of given + declarative model object have changed during the session. Without + parameters this checks if given object has any modificiations. Additionally + exclude parameter can be given to check if given object has any changes + in any attributes other than the ones given in exclude. + + + :: + + + from sqlalchemy_utils import has_changes + + + user = User() + + has_changes(user, 'name') # False + + user.name = u'someone' + + has_changes(user, 'name') # True + + has_changes(user) # True + + + You can check multiple attributes as well. + :: + + + has_changes(user, ['age']) # True + + has_changes(user, ['name', 'age']) # True + + + This function also supports excluding certain attributes. + + :: + + has_changes(user, exclude=['name']) # False + + has_changes(user, exclude=['age']) # True + + .. versionchanged: 0.26.6 + Added support for multiple attributes and exclude parameter. + + :param obj: SQLAlchemy declarative model object + :param attrs: Names of the attributes + :param exclude: Names of the attributes to exclude + """ + if attrs: + if isinstance(attrs, six.string_types): + return ( + sa.inspect(obj) + .attrs + .get(attrs) + .history + .has_changes() + ) + else: + return any(has_changes(obj, attr) for attr in attrs) + else: + if exclude is None: + exclude = [] + return any( + attr.history.has_changes() + for key, attr in sa.inspect(obj).attrs.items() + if key not in exclude + ) + + +def is_loaded(obj, prop): + """ + Return whether or not given property of given object has been loaded. + + :: + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + content = sa.orm.deferred(sa.Column(sa.String)) + + + article = session.query(Article).get(5) + + # name gets loaded since its not a deferred property + assert is_loaded(article, 'name') + + # content has not yet been loaded since its a deferred property + assert not is_loaded(article, 'content') + + + .. versionadded: 0.27.8 + + :param obj: SQLAlchemy declarative model object + :param prop: Name of the property or InstrumentedAttribute + """ + return not isinstance( + getattr(sa.inspect(obj).attrs, prop).loaded_value, + sa.util.langhelpers._symbol + ) + + +def identity(obj_or_class): + """ + Return the identity of given sqlalchemy declarative model class or instance + as a tuple. This differs from obj._sa_instance_state.identity in a way that + it always returns the identity even if object is still in transient state ( + new object that is not yet persisted into database). Also for classes it + returns the identity attributes. + + :: + + from sqlalchemy import inspect + from sqlalchemy_utils import identity + + + user = User(name=u'John Matrix') + session.add(user) + identity(user) # None + inspect(user).identity # None + + session.flush() # User now has id but is still in transient state + + identity(user) # (1,) + inspect(user).identity # None + + session.commit() + + identity(user) # (1,) + inspect(user).identity # (1, ) + + + You can also use identity for classes:: + + + identity(User) # (User.id, ) + + .. versionadded: 0.21.0 + + :param obj: SQLAlchemy declarative model object + """ + return tuple( + getattr(obj_or_class, column_key) + for column_key in get_primary_keys(obj_or_class).keys() + ) + + +def naturally_equivalent(obj, obj2): + """ + Returns whether or not two given SQLAlchemy declarative instances are + naturally equivalent (all their non primary key properties are equivalent). + + + :: + + from sqlalchemy_utils import naturally_equivalent + + + user = User(name=u'someone') + user2 = User(name=u'someone') + + user == user2 # False + + naturally_equivalent(user, user2) # True + + + :param obj: SQLAlchemy declarative model object + :param obj2: SQLAlchemy declarative model object to compare with `obj` + """ + for column_key, column in sa.inspect(obj.__class__).columns.items(): + if column.primary_key: + continue + + if not (getattr(obj, column_key) == getattr(obj2, column_key)): + return False + return True diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/functions/render.py b/python-sqlalchemy-utils/sqlalchemy_utils/functions/render.py new file mode 100644 index 0000000..b6c6054 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/functions/render.py @@ -0,0 +1,72 @@ +import inspect + +import six +import sqlalchemy as sa + +from .mock import create_mock_engine + + +def render_expression(expression, bind, stream=None): + """Generate a SQL expression from the passed python expression. + + Only the global variable, `engine`, is available for use in the + expression. Additional local variables may be passed in the context + parameter. + + Note this function is meant for convenience and protected usage. Do NOT + blindly pass user input to this function as it uses exec. + + :param bind: A SQLAlchemy engine or bind URL. + :param stream: Render all DDL operations to the stream. + """ + + # Create a stream if not present. + + if stream is None: + stream = six.moves.cStringIO() + + engine = create_mock_engine(bind, stream) + + # Navigate the stack and find the calling frame that allows the + # expression to execuate. + + for frame in inspect.stack()[1:]: + try: + frame = frame[0] + local = dict(frame.f_locals) + local['engine'] = engine + six.exec_(expression, frame.f_globals, local) + break + except: + pass + else: + raise ValueError('Not a valid python expression', engine) + + return stream + + +def render_statement(statement, bind=None): + """ + Generate an SQL expression string with bound parameters rendered inline + for the given SQLAlchemy statement. + + :param statement: SQLAlchemy Query object. + :param bind: + Optional SQLAlchemy bind, if None uses the bind of the given query + object. + """ + + if isinstance(statement, sa.orm.query.Query): + if bind is None: + bind = statement.session.get_bind(statement._mapper_zero()) + + statement = statement.statement + + elif bind is None: + bind = statement.bind + + stream = six.moves.cStringIO() + engine = create_mock_engine(bind.engine, stream=stream) + engine.execute(statement) + + return stream.getvalue() diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/functions/sort_query.py b/python-sqlalchemy-utils/sqlalchemy_utils/functions/sort_query.py new file mode 100644 index 0000000..f233c4f --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/functions/sort_query.py @@ -0,0 +1,199 @@ +import sqlalchemy as sa +from sqlalchemy.sql.expression import asc, desc + +from .database import has_unique_index +from .orm import get_query_descriptor, get_tables + + +class QuerySorterException(Exception): + pass + + +class QuerySorter(object): + def __init__(self, silent=True, separator='-'): + self.separator = separator + self.silent = silent + + def assign_order_by(self, entity, attr, func): + expr = get_query_descriptor(self.query, entity, attr) + + if expr is not None: + return self.query.order_by(func(expr)) + if not self.silent: + raise QuerySorterException( + "Could not sort query with expression '%s'" % attr + ) + return self.query + + def parse_sort_arg(self, arg): + if arg[0] == self.separator: + func = desc + arg = arg[1:] + else: + func = asc + + parts = arg.split(self.separator) + return { + 'entity': parts[0] if len(parts) > 1 else None, + 'attr': parts[1] if len(parts) > 1 else arg, + 'func': func + } + + def __call__(self, query, *args): + self.query = query + + for sort in args: + if not sort: + continue + self.query = self.assign_order_by( + **self.parse_sort_arg(sort) + ) + return self.query + + +def sort_query(query, *args, **kwargs): + """ + Applies an sql ORDER BY for given query. This function can be easily used + with user-defined sorting. + + The examples use the following model definition: + + :: + + + import sqlalchemy as sa + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy_utils import sort_query + + + engine = create_engine( + 'sqlite:///' + ) + Base = declarative_base() + Session = sessionmaker(bind=engine) + session = Session() + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + category = sa.orm.relationship( + Category, primaryjoin=category_id == Category.id + ) + + + + 1. Applying simple ascending sort + :: + + + query = session.query(Article) + query = sort_query(query, 'name') + + + 2. Appying descending sort + :: + + + query = sort_query(query, '-name') + + 3. Applying sort to custom calculated label + :: + + + query = session.query( + Category, sa.func.count(Article.id).label('articles') + ) + query = sort_query(query, 'articles') + + 4. Applying sort to joined table column + :: + + + query = session.query(Article).join(Article.category) + query = sort_query(query, 'category-name') + + + :param query: + query to be modified + :param sort: + string that defines the label or column to sort the query by + :param silent: + Whether or not to raise exceptions if unknown sort column + is passed. By default this is `True` indicating that no errors should + be raised for unknown columns. + """ + return QuerySorter(**kwargs)(query, *args) + + +def make_order_by_deterministic(query): + """ + Make query order by deterministic (if it isn't already). Order by is + considered deterministic if it contains column that is unique index ( + either it is a primary key or has a unique index). Many times it is design + flaw to order by queries in nondeterministic manner. + + Consider a User model with three fields: id (primary key), favorite color + and email (unique).:: + + + from sqlalchemy_utils import make_order_by_deterministic + + + query = session.query(User).order_by(User.favorite_color) + + query = make_order_by_deterministic(query) + print query # 'SELECT ... ORDER BY "user".favorite_color, "user".id' + + + query = session.query(User).order_by(User.email) + + query = make_order_by_deterministic(query) + print query # 'SELECT ... ORDER BY "user".email' + + + query = session.query(User).order_by(User.id) + + query = make_order_by_deterministic(query) + print query # 'SELECT ... ORDER BY "user".id' + + + .. versionadded: 0.27.1 + """ + order_by_func = sa.asc + + if not query._order_by: + column = None + else: + order_by = query._order_by[0] + if isinstance(order_by, sa.sql.expression.UnaryExpression): + if order_by.modifier == sa.sql.operators.desc_op: + order_by_func = sa.desc + else: + order_by_func = sa.asc + column = order_by.get_children()[0] + else: + column = order_by + + # Skip queries that are ordered by an already deterministic column + if isinstance(column, sa.Column): + try: + if has_unique_index(column): + return query + except TypeError: + pass + + base_table = get_tables(query._entities[0])[0] + query = query.order_by( + *(order_by_func(c) for c in base_table.c if c.primary_key) + ) + return query diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/generic.py b/python-sqlalchemy-utils/sqlalchemy_utils/generic.py new file mode 100644 index 0000000..e466316 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/generic.py @@ -0,0 +1,182 @@ +from collections import Iterable + +import six +import sqlalchemy as sa +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import attributes, class_mapper, ColumnProperty +from sqlalchemy.orm.interfaces import MapperProperty, PropComparator +from sqlalchemy.orm.session import _state_session +from sqlalchemy.util import set_creation_order + +from sqlalchemy_utils.functions import identity + +from .exceptions import ImproperlyConfigured + + +class GenericAttributeImpl(attributes.ScalarAttributeImpl): + def get(self, state, dict_, passive=attributes.PASSIVE_OFF): + if self.key in dict_: + return dict_[self.key] + + # Retrieve the session bound to the state in order to perform + # a lazy query for the attribute. + session = _state_session(state) + if session is None: + # State is not bound to a session; we cannot proceed. + return None + + # Find class for discriminator. + # TODO: Perhaps optimize with some sort of lookup? + discriminator = self.get_state_discriminator(state) + target_class = state.class_._decl_class_registry.get(discriminator) + + if target_class is None: + # Unknown discriminator; return nothing. + return None + + id = self.get_state_id(state) + + target = session.query(target_class).get(id) + + # Return found (or not found) target. + return target + + def get_state_discriminator(self, state): + discriminator = self.parent_token.discriminator + if isinstance(discriminator, hybrid_property): + return getattr(state.obj(), discriminator.__name__) + else: + return state.attrs[discriminator.key].value + + def get_state_id(self, state): + # Lookup row with the discriminator and id. + return tuple(state.attrs[id.key].value for id in self.parent_token.id) + + def set(self, state, dict_, initiator, + passive=attributes.PASSIVE_OFF, + check_old=None, + pop=False): + + # Set us on the state. + dict_[self.key] = initiator + + if initiator is None: + # Nullify relationship args + for id in self.parent_token.id: + dict_[id.key] = None + dict_[self.parent_token.discriminator.key] = None + else: + # Get the primary key of the initiator and ensure we + # can support this assignment. + class_ = type(initiator) + mapper = class_mapper(class_) + + pk = mapper.identity_key_from_instance(initiator)[1] + + # Set the identifier and the discriminator. + discriminator = six.text_type(class_.__name__) + + for index, id in enumerate(self.parent_token.id): + dict_[id.key] = pk[index] + dict_[self.parent_token.discriminator.key] = discriminator + + +class GenericRelationshipProperty(MapperProperty): + """A generic form of the relationship property. + + Creates a 1 to many relationship between the parent model + and any other models using a descriminator (the table name). + + :param discriminator + Field to discriminate which model we are referring to. + :param id: + Field to point to the model we are referring to. + """ + + def __init__(self, discriminator, id, doc=None): + super(GenericRelationshipProperty, self).__init__() + self._discriminator_col = discriminator + self._id_cols = id + self._id = None + self._discriminator = None + self.doc = doc + + set_creation_order(self) + + def _column_to_property(self, column): + if isinstance(column, hybrid_property): + attr_key = column.__name__ + for key, attr in self.parent.all_orm_descriptors.items(): + if key == attr_key: + return attr + else: + for key, attr in self.parent.attrs.items(): + if isinstance(attr, ColumnProperty): + if attr.columns[0].name == column.name: + return attr + + def init(self): + def convert_strings(column): + if isinstance(column, six.string_types): + return self.parent.columns[column] + return column + + self._discriminator_col = convert_strings(self._discriminator_col) + self._id_cols = convert_strings(self._id_cols) + + if isinstance(self._id_cols, Iterable): + self._id_cols = list(map(convert_strings, self._id_cols)) + else: + self._id_cols = [self._id_cols] + + self.discriminator = self._column_to_property(self._discriminator_col) + + if self.discriminator is None: + raise ImproperlyConfigured( + 'Could not find discriminator descriptor.' + ) + + self.id = list(map(self._column_to_property, self._id_cols)) + + class Comparator(PropComparator): + def __init__(self, prop, parentmapper): + self.property = prop + self._parententity = parentmapper + + def __eq__(self, other): + discriminator = six.text_type(type(other).__name__) + q = self.property._discriminator_col == discriminator + other_id = identity(other) + for index, id in enumerate(self.property._id_cols): + q &= id == other_id[index] + return q + + def __ne__(self, other): + return ~(self == other) + + def is_type(self, other): + mapper = sa.inspect(other) + # Iterate through the weak sequence in order to get the actual + # mappers + class_names = [six.text_type(other.__name__)] + class_names.extend([ + six.text_type(submapper.class_.__name__) + for submapper in mapper._inheriting_mappers + ]) + + return self.property._discriminator_col.in_(class_names) + + def instrument_class(self, mapper): + attributes.register_attribute( + mapper.class_, + self.key, + comparator=self.Comparator(self, mapper), + parententity=mapper, + doc=self.doc, + impl_class=GenericAttributeImpl, + parent_token=self + ) + + +def generic_relationship(*args, **kwargs): + return GenericRelationshipProperty(*args, **kwargs) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/i18n.py b/python-sqlalchemy-utils/sqlalchemy_utils/i18n.py new file mode 100644 index 0000000..06039e3 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/i18n.py @@ -0,0 +1,109 @@ +import six +import sqlalchemy as sa +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.sql.expression import ColumnElement + +from .exceptions import ImproperlyConfigured + +try: + import babel +except ImportError: + babel = None + +try: + from flask.ext.babel import get_locale +except ImportError: + def get_locale(): + raise ImproperlyConfigured( + 'Could not load get_locale function from Flask-Babel. Either ' + 'install Flask-Babel or make a similar function and override it ' + 'in this module.' + ) + + +def cast_locale(obj, locale): + """ + Cast given locale to string. Supports also callbacks that return locales. + + :param obj: + Object or class to use as a possible parameter to locale callable + :param locale: + Locale object or string or callable that returns a locale. + """ + if callable(locale): + try: + locale = locale() + except TypeError: + locale = locale(obj) + if isinstance(locale, babel.Locale): + return str(locale) + return locale + + +class cast_locale_expr(ColumnElement): + def __init__(self, cls, locale): + self.cls = cls + self.locale = locale + + +@compiles(cast_locale_expr) +def compile_cast_locale_expr(element, compiler, **kw): + locale = cast_locale(element.cls, element.locale) + if isinstance(locale, six.string_types): + return "'{0}'".format(locale) + return compiler.process(locale) + + +class TranslationHybrid(object): + def __init__(self, current_locale, default_locale, default_value=None): + if babel is None: + raise ImproperlyConfigured( + 'You need to install babel in order to use TranslationHybrid.' + ) + self.current_locale = current_locale + self.default_locale = default_locale + self.default_value = default_value + + def getter_factory(self, attr): + """ + Return a hybrid_property getter function for given attribute. The + returned getter first checks if object has translation for current + locale. If not it tries to get translation for default locale. If there + is no translation found for default locale it returns None. + """ + def getter(obj): + current_locale = cast_locale(obj, self.current_locale) + try: + return getattr(obj, attr.key)[current_locale] + except (TypeError, KeyError): + default_locale = cast_locale( + obj, self.default_locale + ) + try: + return getattr(obj, attr.key)[default_locale] + except (TypeError, KeyError): + return self.default_value + return getter + + def setter_factory(self, attr): + def setter(obj, value): + if getattr(obj, attr.key) is None: + setattr(obj, attr.key, {}) + locale = cast_locale(obj, self.current_locale) + getattr(obj, attr.key)[locale] = value + return setter + + def expr_factory(self, attr): + def expr(cls): + current_locale = cast_locale_expr(cls, self.current_locale) + default_locale = cast_locale_expr(cls, self.default_locale) + return sa.func.coalesce(attr[current_locale], attr[default_locale]) + return expr + + def __call__(self, attr): + return hybrid_property( + fget=self.getter_factory(attr), + fset=self.setter_factory(attr), + expr=self.expr_factory(attr) + ) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/listeners.py b/python-sqlalchemy-utils/sqlalchemy_utils/listeners.py new file mode 100644 index 0000000..ce6d774 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/listeners.py @@ -0,0 +1,255 @@ +import sqlalchemy as sa + +from .exceptions import ImproperlyConfigured + + +def coercion_listener(mapper, class_): + """ + Auto assigns coercing listener for all class properties which are of coerce + capable type. + """ + for prop in mapper.iterate_properties: + try: + listener = prop.columns[0].type.coercion_listener + except AttributeError: + continue + sa.event.listen( + getattr(class_, prop.key), + 'set', + listener, + retval=True + ) + + +def instant_defaults_listener(target, args, kwargs): + for key, column in sa.inspect(target.__class__).columns.items(): + if column.default is not None: + if callable(column.default.arg): + setattr(target, key, column.default.arg(target)) + else: + setattr(target, key, column.default.arg) + + +def force_auto_coercion(mapper=None): + """ + Function that assigns automatic data type coercion for all classes which + are of type of given mapper. The coercion is applied to all coercion + capable properties. By default coercion is applied to all SQLAlchemy + mappers. + + Before initializing your models you need to call force_auto_coercion. + + :: + + from sqlalchemy_utils import force_auto_coercion + + + force_auto_coercion() + + + Then define your models the usual way:: + + + class Document(Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(50)) + background_color = sa.Column(ColorType) + + + Now scalar values for coercion capable data types will convert to + appropriate value objects:: + + document = Document() + document.background_color = 'F5F5F5' + document.background_color # Color object + session.commit() + + + :param mapper: The mapper which the automatic data type coercion should be + applied to + """ + if mapper is None: + mapper = sa.orm.mapper + sa.event.listen(mapper, 'mapper_configured', coercion_listener) + + +def force_instant_defaults(mapper=None): + """ + Function that assigns object column defaults on object initialization + time. By default calling this function applies instant defaults to all + your models. + + Setting up instant defaults:: + + + from sqlalchemy_utils import force_instant_defaults + + + force_instant_defaults() + + Example usage:: + + + class Document(Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(50)) + created_at = sa.Column(sa.DateTime, default=datetime.now) + + + document = Document() + document.created_at # datetime object + + + :param mapper: The mapper which the automatic instant defaults forcing + should be applied to + """ + if mapper is None: + mapper = sa.orm.mapper + sa.event.listen(mapper, 'init', instant_defaults_listener) + + +def auto_delete_orphans(attr): + """ + Delete orphans for given SQLAlchemy model attribute. This function can be + used for deleting many-to-many associated orphans easily. For more + information see + https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/ManyToManyOrphan. + + Consider the following model definition: + + :: + + from sqlalchemy.ext.associationproxy import association_proxy + from sqlalchemy import * + from sqlalchemy.orm import * + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy import event + + + Base = declarative_base() + + tagging = Table( + 'tagging', + Base.metadata, + Column( + 'tag_id', + Integer, + ForeignKey('tag.id', ondelete='CASCADE'), + primary_key=True + ), + Column( + 'entry_id', + Integer, + ForeignKey('entry.id', ondelete='CASCADE'), + primary_key=True + ) + ) + + class Tag(Base): + __tablename__ = 'tag' + id = Column(Integer, primary_key=True) + name = Column(String(100), unique=True, nullable=False) + + def __init__(self, name=None): + self.name = name + + class Entry(Base): + __tablename__ = 'entry' + + id = Column(Integer, primary_key=True) + + tags = relationship( + 'Tag', + secondary=tagging, + backref='entries' + ) + + Now lets say we want to delete the tags if all their parents get deleted ( + all Entry objects get deleted). This can be achieved as follows: + + :: + + + from sqlalchemy_utils import auto_delete_orphans + + + auto_delete_orphans(Entry.tags) + + + After we've set up this listener we can see it in action. + + :: + + + e = create_engine('sqlite://') + + Base.metadata.create_all(e) + + s = Session(e) + + r1 = Entry() + r2 = Entry() + r3 = Entry() + t1, t2, t3, t4 = Tag('t1'), Tag('t2'), Tag('t3'), Tag('t4') + + r1.tags.extend([t1, t2]) + r2.tags.extend([t2, t3]) + r3.tags.extend([t4]) + s.add_all([r1, r2, r3]) + + assert s.query(Tag).count() == 4 + + r2.tags.remove(t2) + + assert s.query(Tag).count() == 4 + + r1.tags.remove(t2) + + assert s.query(Tag).count() == 3 + + r1.tags.remove(t1) + + assert s.query(Tag).count() == 2 + + .. versionadded: 0.26.4 + + :param attr: Association relationship attribute to auto delete orphans from + """ + + parent_class = attr.parent.class_ + target_class = attr.property.mapper.class_ + + backref = attr.property.backref + if not backref: + raise ImproperlyConfigured( + 'The relationship argument given for auto_delete_orphans needs to ' + 'have a backref relationship set.' + ) + + @sa.event.listens_for(sa.orm.Session, 'after_flush') + def delete_orphan_listener(session, ctx): + # Look through Session state to see if we want to emit a DELETE for + # orphans + orphans_found = ( + any( + isinstance(obj, parent_class) and + sa.orm.attributes.get_history(obj, attr.key).deleted + for obj in session.dirty + ) or + any( + isinstance(obj, parent_class) + for obj in session.deleted + ) + ) + + if orphans_found: + # Emit a DELETE for all orphans + ( + session.query(target_class) + .filter( + ~getattr(target_class, attr.property.backref).any() + ) + .delete(synchronize_session=False) + ) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/models.py b/python-sqlalchemy-utils/sqlalchemy_utils/models.py new file mode 100644 index 0000000..80a85ec --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/models.py @@ -0,0 +1,33 @@ +from datetime import datetime + +import sqlalchemy as sa + + +class Timestamp(object): + """Adds `created` and `updated` columns to a derived declarative model. + + The `created` column is handled through a default and the `updated` + column is handled through a `before_update` event that propagates + for all derived declarative models. + + :: + + + import sqlalchemy as sa + from sqlalchemy_utils import Timestamp + + + class SomeModel(Base, Timestamp): + __tablename__ = 'somemodel' + id = sa.Column(sa.Integer, primary_key=True) + """ + + created = sa.Column(sa.DateTime, default=datetime.utcnow, nullable=False) + updated = sa.Column(sa.DateTime, default=datetime.utcnow, nullable=False) + + +@sa.event.listens_for(Timestamp, 'before_update', propagate=True) +def timestamp_before_update(mapper, connection, target): + # When a model with a timestamp is updated; force update the updated + # timestamp. + target.updated = datetime.utcnow() diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/observer.py b/python-sqlalchemy-utils/sqlalchemy_utils/observer.py new file mode 100644 index 0000000..c109480 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/observer.py @@ -0,0 +1,333 @@ +""" +This module provides a decorator function for observing changes in a given +property. Internally the decorator is implemented using SQLAlchemy event +listeners. Both column properties and relationship properties can be observed. + +Property observers can be used for pre-calculating aggregates and automatic +real-time data denormalization. + +Simple observers +---------------- + +At the heart of the observer extension is the :func:`observes` decorator. You +mark some property path as being observed and the marked method will get +notified when any changes are made to given path. + +Consider the following model structure: + +:: + + class Director(Base): + __tablename__ = 'director' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + date_of_birth = sa.Column(sa.Date) + + class Movie(Base): + __tablename__ = 'movie' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + director_id = sa.Column(sa.Integer, sa.ForeignKey(Director.id)) + director = sa.orm.relationship(Director, backref='movies') + + +Now consider we want to show movies in some listing ordered by director id +first and movie id secondly. If we have many movies then using joins and +ordering by Director.name will be very slow. Here is where denormalization +and :func:`observes` comes to rescue the day. Let's add a new column called +director_name to Movie which will get automatically copied from associated +Director. + + +:: + + from sqlalchemy_utils import observes + + + class Movie(Base): + # same as before.. + director_name = sa.Column(sa.String) + + @observes('director') + def director_observer(self, director): + self.director_name = director.name + +.. note:: + + This example could be done much more efficiently using a compound foreign + key from director_name, director_id to Director.name, Director.id but for + the sake of simplicity we added this as an example. + + +Observes vs aggregated +---------------------- + +:func:`observes` and :func:`.aggregates.aggregated` can be used for similar +things. However performance wise you should take the following things into +consideration: + +* :func:`observes` works always inside transaction and deals with objects. If + the relationship observer is observing has a large number of objects it's + better to use :func:`.aggregates.aggregated`. +* :func:`.aggregates.aggregated` always executes one additional query per + aggregate so in scenarios where the observed relationship has only a handful + of objects it's better to use :func:`observes` instead. + + +Example 1. Movie with many ratings + +Let's say we have a Movie object with potentially thousands of ratings. In this +case we should always use :func:`.aggregates.aggregated` since iterating +through thousands of objects is slow and very memory consuming. + +Example 2. Product with denormalized catalog name + +Each product belongs to one catalog. Here it is natural to use :func:`observes` +for data denormalization. + + +Deeply nested observing +----------------------- + +Consider the following model structure where Catalog has many Categories and +Category has many Products. + +:: + + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + products = sa.orm.relationship('Product', backref='category') + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + + +:func:`observes` is smart enough to: + +* Notify catalog objects of any changes in associated Product objects +* Notify catalog objects of any changes in Category objects that affect + products (for example if Category gets deleted, or a new Category is added to + Catalog with any number of Products) + + +:: + + category = Category( + products=[Product(), Product()] + ) + category2 = Category( + product=[Product()] + ) + + catalog = Catalog( + categories=[category, category2] + ) + session.add(catalog) + session.commit() + catalog.product_count # 2 + + session.delete(category) + session.commit() + catalog.product_count # 1 + +""" +import itertools +from collections import defaultdict, Iterable, namedtuple + +import sqlalchemy as sa + +from sqlalchemy_utils.functions import getdotattr +from sqlalchemy_utils.path import AttrPath +from sqlalchemy_utils.utils import is_sequence + +Callback = namedtuple('Callback', ['func', 'path', 'backref', 'fullpath']) + + +class PropertyObserver(object): + def __init__(self): + self.listener_args = [ + ( + sa.orm.mapper, + 'mapper_configured', + self.update_generator_registry + ), + ( + sa.orm.mapper, + 'after_configured', + self.gather_paths + ), + ( + sa.orm.session.Session, + 'before_flush', + self.invoke_callbacks + ) + ] + self.callback_map = defaultdict(list) + # TODO: make the registry a WeakKey dict + self.generator_registry = defaultdict(list) + + def remove_listeners(self): + for args in self.listener_args: + sa.event.remove(*args) + + def register_listeners(self): + for args in self.listener_args: + if not sa.event.contains(*args): + sa.event.listen(*args) + + def __repr__(self): + return '' + + def update_generator_registry(self, mapper, class_): + """ + Adds generator functions to generator_registry. + """ + + for generator in class_.__dict__.values(): + if hasattr(generator, '__observes__'): + self.generator_registry[class_].append( + generator + ) + + def gather_paths(self): + for class_, callbacks in self.generator_registry.items(): + for callback in callbacks: + path = AttrPath(class_, callback.__observes__) + + self.callback_map[class_].append( + Callback( + func=callback, + path=path, + backref=None, + fullpath=path + ) + ) + + for index in range(len(path)): + i = index + 1 + prop = path[index].property + if isinstance(prop, sa.orm.RelationshipProperty): + prop_class = path[index].property.mapper.class_ + self.callback_map[prop_class].append( + Callback( + func=callback, + path=path[i:], + backref=~ (path[:i]), + fullpath=path + ) + ) + + def gather_callback_args(self, obj, callbacks): + session = sa.orm.object_session(obj) + for callback in callbacks: + backref = callback.backref + + root_objs = getdotattr(obj, backref) if backref else obj + if root_objs: + if not isinstance(root_objs, Iterable): + root_objs = [root_objs] + + for root_obj in root_objs: + objects = getdotattr( + root_obj, + callback.fullpath, + lambda obj: obj not in session.deleted + ) + + yield ( + root_obj, + callback.func, + objects + ) + + def changed_objects(self, session): + objs = itertools.chain(session.new, session.dirty, session.deleted) + for obj in objs: + for class_, callbacks in self.callback_map.items(): + if isinstance(obj, class_): + yield obj, callbacks + + def invoke_callbacks(self, session, ctx, instances): + callback_args = defaultdict(lambda: defaultdict(set)) + for obj, callbacks in self.changed_objects(session): + args = self.gather_callback_args(obj, callbacks) + for (root_obj, func, objects) in args: + if is_sequence(objects): + callback_args[root_obj][func] = ( + callback_args[root_obj][func] | set(objects) + ) + else: + callback_args[root_obj][func] = objects + + for root_obj, callback_objs in callback_args.items(): + for callback, objs in callback_objs.items(): + callback(root_obj, objs) + +observer = PropertyObserver() + + +def observes(path, observer=observer): + """ + Mark method as property observer for the given property path. Inside + transaction observer gathers all changes made in given property path and + feeds the changed objects to observer-marked method at the before flush + phase. + + :: + + from sqlalchemy_utils import observes + + + class Catalog(Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + category_count = sa.Column(sa.Integer, default=0) + + @observes('categories') + def category_observer(self, categories): + self.category_count = len(categories) + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + + catalog = Catalog(categories=[Category(), Category()]) + session.add(catalog) + session.commit() + + catalog.category_count # 2 + + + .. versionadded: 0.28.0 + + :param path: Dot-notated property path, eg. 'categories.products.price' + :param observer: :meth:`PropertyObserver` object + """ + observer.register_listeners() + + def wraps(func): + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + wrapper.__observes__ = path + return wrapper + return wraps diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/operators.py b/python-sqlalchemy-utils/sqlalchemy_utils/operators.py new file mode 100644 index 0000000..3092763 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/operators.py @@ -0,0 +1,74 @@ +import sqlalchemy as sa + + +def inspect_type(mixed): + if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): + return mixed.property.columns[0].type + elif isinstance(mixed, sa.orm.ColumnProperty): + return mixed.columns[0].type + elif isinstance(mixed, sa.Column): + return mixed.type + + +def is_case_insensitive(mixed): + try: + return isinstance( + inspect_type(mixed).comparator, + CaseInsensitiveComparator + ) + except AttributeError: + try: + return issubclass( + inspect_type(mixed).comparator_factory, + CaseInsensitiveComparator + ) + except AttributeError: + return False + + +class CaseInsensitiveComparator(sa.Unicode.Comparator): + @classmethod + def lowercase_arg(cls, func): + def operation(self, other, **kwargs): + operator = getattr(sa.Unicode.Comparator, func) + if other is None: + return operator(self, other, **kwargs) + if not is_case_insensitive(other): + other = sa.func.lower(other) + return operator(self, other, **kwargs) + return operation + + def in_(self, other): + if isinstance(other, list) or isinstance(other, tuple): + other = map(sa.func.lower, other) + return sa.Unicode.Comparator.in_(self, other) + + def notin_(self, other): + if isinstance(other, list) or isinstance(other, tuple): + other = map(sa.func.lower, other) + return sa.Unicode.Comparator.notin_(self, other) + + +string_operator_funcs = [ + '__eq__', + '__ne__', + '__lt__', + '__le__', + '__gt__', + '__ge__', + 'concat', + 'contains', + 'ilike', + 'like', + 'notlike', + 'notilike', + 'startswith', + 'endswith', +] + +for func in string_operator_funcs: + setattr( + CaseInsensitiveComparator, + func, + CaseInsensitiveComparator.lowercase_arg(func) + ) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/path.py b/python-sqlalchemy-utils/sqlalchemy_utils/path.py new file mode 100644 index 0000000..57ec694 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/path.py @@ -0,0 +1,154 @@ +import sqlalchemy as sa +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.util.langhelpers import symbol + +from .utils import str_coercible + + +@str_coercible +class Path(object): + def __init__(self, path, separator='.'): + if isinstance(path, Path): + self.path = path.path + else: + self.path = path + self.separator = separator + + @property + def parts(self): + return self.path.split(self.separator) + + def __iter__(self): + for part in self.parts: + yield part + + def __len__(self): + return len(self.parts) + + def __repr__(self): + return "%s('%s')" % (self.__class__.__name__, self.path) + + def index(self, element): + return self.parts.index(element) + + def __getitem__(self, slice): + result = self.parts[slice] + if isinstance(result, list): + return self.__class__( + self.separator.join(result), + separator=self.separator + ) + return result + + def __eq__(self, other): + return self.path == other.path and self.separator == other.separator + + def __ne__(self, other): + return not (self == other) + + def __unicode__(self): + return self.path + + +def get_attr(mixed, attr): + if isinstance(mixed, InstrumentedAttribute): + return getattr( + mixed.property.mapper.class_, + attr + ) + else: + return getattr(mixed, attr) + + +@str_coercible +class AttrPath(object): + def __init__(self, class_, path): + self.class_ = class_ + self.path = Path(path) + self.parts = [] + last_attr = class_ + for value in self.path: + last_attr = get_attr(last_attr, value) + self.parts.append(last_attr) + + def __iter__(self): + for part in self.parts: + yield part + + def __invert__(self): + def get_backref(part): + prop = part.property + backref = prop.backref or prop.back_populates + if backref is None: + raise Exception( + "Invert failed because property '%s' of class " + "%s has no backref." % ( + prop.key, + prop.parent.class_.__name__ + ) + ) + if isinstance(backref, tuple): + return backref[0] + else: + return backref + + if isinstance(self.parts[-1].property, sa.orm.ColumnProperty): + class_ = self.parts[-1].class_ + else: + class_ = self.parts[-1].mapper.class_ + + return self.__class__( + class_, + '.'.join(map(get_backref, reversed(self.parts))) + ) + + def index(self, element): + for index, el in enumerate(self.parts): + if el is element: + return index + + @property + def direction(self): + symbols = [part.property.direction for part in self.parts] + if symbol('MANYTOMANY') in symbols: + return symbol('MANYTOMANY') + elif symbol('MANYTOONE') in symbols and symbol('ONETOMANY') in symbols: + return symbol('MANYTOMANY') + return symbols[0] + + @property + def uselist(self): + return any(part.property.uselist for part in self.parts) + + def __getitem__(self, slice): + result = self.parts[slice] + if isinstance(result, list) and result: + if result[0] is self.parts[0]: + class_ = self.class_ + else: + class_ = result[0].parent.class_ + return self.__class__( + class_, + self.path[slice] + ) + else: + return result + + def __len__(self): + return len(self.path) + + def __repr__(self): + return "%s(%s, %r)" % ( + self.__class__.__name__, + self.class_.__name__, + self.path.path + ) + + def __eq__(self, other): + return self.path == other.path and self.class_ == other.class_ + + def __ne__(self, other): + return not (self == other) + + def __unicode__(self): + return str(self.path) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/primitives/__init__.py b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/__init__.py new file mode 100644 index 0000000..71a5829 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/__init__.py @@ -0,0 +1,4 @@ +from .country import Country # noqa +from .currency import Currency # noqa +from .weekday import WeekDay # noqa +from .weekdays import WeekDays # noqa diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/primitives/country.py b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/country.py new file mode 100644 index 0000000..01e6261 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/country.py @@ -0,0 +1,98 @@ +import six + +from sqlalchemy_utils import i18n +from sqlalchemy_utils.utils import str_coercible + + +@str_coercible +class Country(object): + """ + Country class wraps a 2 to 3 letter country code. It provides various + convenience properties and methods. + + :: + + from babel import Locale + from sqlalchemy_utils import Country, i18n + + + # First lets add a locale getter for testing purposes + i18n.get_locale = lambda: Locale('en') + + + Country('FI').name # Finland + Country('FI').code # FI + + Country(Country('FI')).code # 'FI' + + Country always validates the given code. + + :: + + Country(None) # raises TypeError + + Country('UnknownCode') # raises ValueError + + + Country supports equality operators. + + :: + + Country('FI') == Country('FI') + Country('FI') != Country('US') + + + Country objects are hashable. + + + :: + + assert hash(Country('FI')) == hash('FI') + + """ + def __init__(self, code_or_country): + if isinstance(code_or_country, Country): + self.code = code_or_country.code + elif isinstance(code_or_country, six.string_types): + self.validate(code_or_country) + self.code = code_or_country + else: + raise TypeError( + "Country() argument must be a string or a country, not '{0}'" + .format( + type(code_or_country).__name__ + ) + ) + + @property + def name(self): + return i18n.get_locale().territories[self.code] + + @classmethod + def validate(self, code): + try: + i18n.babel.Locale('en').territories[code] + except KeyError: + raise ValueError( + 'Could not convert string to country code: {0}'.format(code) + ) + + def __eq__(self, other): + if isinstance(other, Country): + return self.code == other.code + elif isinstance(other, six.string_types): + return self.code == other + else: + return NotImplemented + + def __hash__(self): + return hash(self.code) + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.code) + + def __unicode__(self): + return self.name diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/primitives/currency.py b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/currency.py new file mode 100644 index 0000000..d27f688 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/currency.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +import six + +from sqlalchemy_utils import i18n, ImproperlyConfigured +from sqlalchemy_utils.utils import str_coercible + + +@str_coercible +class Currency(object): + """ + Currency class wraps a 3-letter currency code. It provides various + convenience properties and methods. + + :: + + from babel import Locale + from sqlalchemy_utils import Currency, i18n + + + # First lets add a locale getter for testing purposes + i18n.get_locale = lambda: Locale('en') + + + Currency('USD').name # US Dollar + Currency('USD').symbol # $ + + Currency(Currency('USD')).code # 'USD' + + Currency always validates the given code. + + :: + + Currency(None) # raises TypeError + + Currency('UnknownCode') # raises ValueError + + + Currency supports equality operators. + + :: + + Currency('USD') == Currency('USD') + Currency('USD') != Currency('EUR') + + + Currencies are hashable. + + + :: + + len(set([Currency('USD'), Currency('USD')])) # 1 + + + """ + def __init__(self, code): + if i18n.babel is None: + raise ImproperlyConfigured( + "'babel' package is required in order to use Currency class." + ) + if isinstance(code, Currency): + self.code = code + elif isinstance(code, six.string_types): + self.validate(code) + self.code = code + else: + raise TypeError( + 'First argument given to Currency constructor should be ' + 'either an instance of Currency or valid three letter ' + 'currency code.' + ) + + @classmethod + def validate(self, code): + try: + i18n.babel.Locale('en').currencies[code] + except KeyError: + raise ValueError("{0}' is not valid currency code.") + + @property + def symbol(self): + return i18n.babel.numbers.get_currency_symbol( + self.code, + i18n.get_locale() + ) + + @property + def name(self): + return i18n.get_locale().currencies[self.code] + + def __eq__(self, other): + if isinstance(other, Currency): + return self.code == other.code + elif isinstance(other, six.string_types): + return self.code == other + else: + return NotImplemented + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash(self.code) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.code) + + def __unicode__(self): + return self.code diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/primitives/weekday.py b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/weekday.py new file mode 100644 index 0000000..29a4443 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/weekday.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +try: + from functools import total_ordering +except ImportError: + # Python 2.6 port + from total_ordering import total_ordering +from sqlalchemy_utils import i18n +from sqlalchemy_utils.utils import str_coercible + + +@str_coercible +@total_ordering +class WeekDay(object): + NUM_WEEK_DAYS = 7 + + def __init__(self, index): + if not (0 <= index < self.NUM_WEEK_DAYS): + raise ValueError( + "index must be between 0 and %d" % self.NUM_WEEK_DAYS + ) + self.index = index + + def __eq__(self, other): + if isinstance(other, WeekDay): + return self.index == other.index + else: + return NotImplemented + + def __hash__(self): + return hash(self.index) + + def __lt__(self, other): + return self.position < other.position + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.index) + + def __unicode__(self): + return self.name + + def get_name(self, width='wide', context='format'): + names = i18n.babel.dates.get_day_names( + width, + context, + i18n.get_locale() + ) + return names[self.index] + + @property + def name(self): + return self.get_name() + + @property + def position(self): + return ( + self.index - + i18n.get_locale().first_week_day + ) % self.NUM_WEEK_DAYS diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/primitives/weekdays.py b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/weekdays.py new file mode 100644 index 0000000..b6aedab --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/primitives/weekdays.py @@ -0,0 +1,61 @@ +import six + +from sqlalchemy_utils.utils import str_coercible + +from .weekday import WeekDay + + +@str_coercible +class WeekDays(object): + def __init__(self, bit_string_or_week_days): + if isinstance(bit_string_or_week_days, six.string_types): + self._days = set() + + if len(bit_string_or_week_days) != WeekDay.NUM_WEEK_DAYS: + raise ValueError( + 'Bit string must be {0} characters long.'.format( + WeekDay.NUM_WEEK_DAYS + ) + ) + + for index, bit in enumerate(bit_string_or_week_days): + if bit not in '01': + raise ValueError( + 'Bit string may only contain zeroes and ones.' + ) + if bit == '1': + self._days.add(WeekDay(index)) + elif isinstance(bit_string_or_week_days, WeekDays): + self._days = bit_string_or_week_days._days + else: + self._days = set(bit_string_or_week_days) + + def __eq__(self, other): + if isinstance(other, WeekDays): + return self._days == other._days + elif isinstance(other, six.string_types): + return self.as_bit_string() == other + else: + return NotImplemented + + def __iter__(self): + for day in sorted(self._days): + yield day + + def __contains__(self, value): + return value in self._days + + def __repr__(self): + return '%s(%r)' % ( + self.__class__.__name__, + self.as_bit_string() + ) + + def __unicode__(self): + return u', '.join(six.text_type(day) for day in self) + + def as_bit_string(self): + return ''.join( + '1' if WeekDay(index) in self._days else '0' + for index in six.moves.xrange(WeekDay.NUM_WEEK_DAYS) + ) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/proxy_dict.py b/python-sqlalchemy-utils/sqlalchemy_utils/proxy_dict.py new file mode 100644 index 0000000..e6ef228 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/proxy_dict.py @@ -0,0 +1,84 @@ +import sqlalchemy as sa + + +class ProxyDict(object): + def __init__(self, parent, collection_name, mapping_attr): + self.parent = parent + self.collection_name = collection_name + self.child_class = mapping_attr.class_ + self.key_name = mapping_attr.key + self.cache = {} + + @property + def collection(self): + return getattr(self.parent, self.collection_name) + + def keys(self): + descriptor = getattr(self.child_class, self.key_name) + return [x[0] for x in self.collection.values(descriptor)] + + def __contains__(self, key): + if key in self.cache: + return self.cache[key] is not None + return self.fetch(key) is not None + + def has_key(self, key): + return self.__contains__(key) + + def fetch(self, key): + session = sa.orm.object_session(self.parent) + if session and sa.orm.util.has_identity(self.parent): + obj = self.collection.filter_by(**{self.key_name: key}).first() + self.cache[key] = obj + return obj + + def create_new_instance(self, key): + value = self.child_class(**{self.key_name: key}) + self.collection.append(value) + self.cache[key] = value + return value + + def __getitem__(self, key): + if key in self.cache: + if self.cache[key] is not None: + return self.cache[key] + else: + value = self.fetch(key) + if value: + return value + + return self.create_new_instance(key) + + def __setitem__(self, key, value): + try: + existing = self[key] + self.collection.remove(existing) + except KeyError: + pass + self.collection.append(value) + self.cache[key] = value + + +def proxy_dict(parent, collection_name, mapping_attr): + try: + parent._proxy_dicts + except AttributeError: + parent._proxy_dicts = {} + + try: + return parent._proxy_dicts[collection_name] + except KeyError: + parent._proxy_dicts[collection_name] = ProxyDict( + parent, + collection_name, + mapping_attr + ) + return parent._proxy_dicts[collection_name] + + +def expire_proxy_dicts(target, context): + if hasattr(target, '_proxy_dicts'): + target._proxy_dicts = {} + + +sa.event.listen(sa.orm.mapper, 'expire', expire_proxy_dicts) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/query_chain.py b/python-sqlalchemy-utils/sqlalchemy_utils/query_chain.py new file mode 100644 index 0000000..1596b43 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/query_chain.py @@ -0,0 +1,173 @@ +""" +QueryChain is a wrapper for sequence of queries. + + +Features: + + * Easy iteration for sequence of queries + * Limit, offset and count which are applied to all queries in the chain + * Smart __getitem__ support + + +Initialization +^^^^^^^^^^^^^^ + +QueryChain takes iterable of queries as first argument. Additionally limit and +offset parameters can be given + +:: + + chain = QueryChain([session.query(User), session.query(Article)]) + + chain = QueryChain( + [session.query(User), session.query(Article)], + limit=4 + ) + + +Simple iteration +^^^^^^^^^^^^^^^^ +:: + + chain = QueryChain([session.query(User), session.query(Article)]) + + for obj in chain: + print obj + + +Limit and offset +^^^^^^^^^^^^^^^^ + +Lets say you have 5 blog posts, 5 articles and 5 news items in your +database. + +:: + + chain = QueryChain( + [ + session.query(BlogPost), + session.query(Article), + session.query(NewsItem) + ], + limit=5 + ) + + list(chain) # all blog posts but not articles and news items + + + chain = chain.offset(4) + list(chain) # last blog post, and first four articles + + +Just like with original query object the limit and offset can be chained to +return a new QueryChain. + +:: + + chain = chain.limit(5).offset(7) + + +Chain slicing +^^^^^^^^^^^^^ + +:: + + chain = QueryChain( + [ + session.query(BlogPost), + session.query(Article), + session.query(NewsItem) + ] + ) + + chain[3:6] # New QueryChain with offset=3 and limit=6 + + +Count +^^^^^ + +Let's assume that there are five blog posts, five articles and five news +items in the database, and you have the following query chain:: + + chain = QueryChain( + [ + session.query(BlogPost), + session.query(Article), + session.query(NewsItem) + ] + ) + +You can then get the total number rows returned by the query chain +with :meth:`~QueryChain.count`:: + + >>> chain.count() + 15 + + +""" +from copy import copy + + +class QueryChain(object): + """ + QueryChain can be used as a wrapper for sequence of queries. + + :param queries: A sequence of SQLAlchemy Query objects + :param limit: Similar to normal query limit this parameter can be used for + limiting the number of results for the whole query chain. + :param offset: Similar to normal query offset this parameter can be used + for offsetting the query chain as a whole. + + .. versionadded: 0.26.0 + """ + def __init__(self, queries, limit=None, offset=None): + self.queries = queries + self._limit = limit + self._offset = offset + + def __iter__(self): + consumed = 0 + skipped = 0 + for query in self.queries: + query_copy = copy(query) + if self._limit: + query = query.limit(self._limit - consumed) + if self._offset: + query = query.offset(self._offset - skipped) + + obj_count = 0 + for obj in query: + consumed += 1 + obj_count += 1 + yield obj + + if not obj_count: + skipped += query_copy.count() + else: + skipped += obj_count + + def limit(self, value): + return self[:value] + + def offset(self, value): + return self[value:] + + def count(self): + """ + Return the total number of rows this QueryChain's queries would return. + """ + return sum(q.count() for q in self.queries) + + def __getitem__(self, key): + if isinstance(key, slice): + return self.__class__( + queries=self.queries, + limit=key.stop if key.stop is not None else self._limit, + offset=key.start if key.start is not None else self._offset + ) + else: + for obj in self[key:1]: + return obj + + def __repr__(self): + return '' % id(self) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/relationships/__init__.py b/python-sqlalchemy-utils/sqlalchemy_utils/relationships/__init__.py new file mode 100644 index 0000000..5bec6f5 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/relationships/__init__.py @@ -0,0 +1,2 @@ +from .chained_join import chained_join # noqa +from .select_aggregate import select_aggregate # noqa diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/relationships/chained_join.py b/python-sqlalchemy-utils/sqlalchemy_utils/relationships/chained_join.py new file mode 100644 index 0000000..2508d1a --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/relationships/chained_join.py @@ -0,0 +1,31 @@ +def chained_join(*relationships): + """ + Return a chained Join object for given relationships. + """ + property_ = relationships[0].property + + if property_.secondary is not None: + from_ = property_.secondary.join( + property_.mapper.class_.__table__, + property_.secondaryjoin + ) + else: + from_ = property_.mapper.class_.__table__ + for relationship in relationships[1:]: + prop = relationship.property + if prop.secondary is not None: + from_ = from_.join( + prop.secondary, + prop.primaryjoin + ) + + from_ = from_.join( + prop.mapper.class_, + prop.secondaryjoin + ) + else: + from_ = from_.join( + prop.mapper.class_, + prop.primaryjoin + ) + return from_ diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/relationships/select_aggregate.py b/python-sqlalchemy-utils/sqlalchemy_utils/relationships/select_aggregate.py new file mode 100644 index 0000000..379b33f --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/relationships/select_aggregate.py @@ -0,0 +1,51 @@ +import sqlalchemy as sa + + +def select_aggregate(agg_expr, relationships): + """ + Return a subquery for fetching an aggregate value of given aggregate + expression and given sequence of relationships. + + The returned aggregate query can be used when updating denormalized column + value with query such as: + + UPDATE table SET column = {aggregate_query} + WHERE {condition} + + :param agg_expr: + an expression to be selected, for example sa.func.count('1') + :param relationships: + Sequence of relationships to be used for building the aggregate + query. + """ + from_ = relationships[0].mapper.class_.__table__ + for relationship in relationships[0:-1]: + property_ = relationship.property + if property_.secondary is not None: + from_ = from_.join( + property_.secondary, + property_.secondaryjoin + ) + + from_ = ( + from_ + .join( + property_.parent.class_, + property_.primaryjoin + ) + ) + + prop = relationships[-1].property + condition = prop.primaryjoin + if prop.secondary is not None: + from_ = from_.join( + prop.secondary, + prop.secondaryjoin + ) + + query = sa.select( + [agg_expr], + from_obj=[from_] + ) + + return query.where(condition) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/__init__.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/__init__.py new file mode 100644 index 0000000..272b001 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/__init__.py @@ -0,0 +1,52 @@ +from functools import wraps + +from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList + +from .arrow import ArrowType # noqa +from .choice import Choice, ChoiceType # noqa +from .color import ColorType # noqa +from .country import CountryType # noqa +from .currency import CurrencyType # noqa +from .email import EmailType # noqa +from .encrypted import EncryptedType # noqa +from .ip_address import IPAddressType # noqa +from .json import JSONType # noqa +from .locale import LocaleType # noqa +from .password import Password, PasswordType # noqa +from .pg_composite import ( # noqa + CompositeArray, + CompositeType, + register_composites, + remove_composite_listeners +) +from .phone_number import PhoneNumber, PhoneNumberType # noqa +from .range import ( # noqa + DateRangeType, + DateTimeRangeType, + IntRangeType, + NumericRangeType +) +from .scalar_list import ScalarListException, ScalarListType # noqa +from .timezone import TimezoneType # noqa +from .ts_vector import TSVectorType # noqa +from .url import URLType # noqa +from .uuid import UUIDType # noqa +from .weekdays import WeekDaysType # noqa + + +class InstrumentedList(_InstrumentedList): + """Enhanced version of SQLAlchemy InstrumentedList. Provides some + additional functionality.""" + + def any(self, attr): + return any(getattr(item, attr) for item in self) + + def all(self, attr): + return all(getattr(item, attr) for item in self) + + +def instrumented_list(f): + @wraps(f) + def wrapper(*args, **kwargs): + return InstrumentedList([item for item in f(*args, **kwargs)]) + return wrapper diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/arrow.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/arrow.py new file mode 100644 index 0000000..d5a3bd0 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/arrow.py @@ -0,0 +1,91 @@ +from __future__ import absolute_import + +from collections import Iterable +from datetime import datetime + +import six +from sqlalchemy import types + +from sqlalchemy_utils.exceptions import ImproperlyConfigured + +from .scalar_coercible import ScalarCoercible + +arrow = None +try: + import arrow +except: + pass + + +class ArrowType(types.TypeDecorator, ScalarCoercible): + """ + ArrowType provides way of saving Arrow_ objects into database. It + automatically changes Arrow_ objects to datetime objects on the way in and + datetime objects back to Arrow_ objects on the way out (when querying + database). ArrowType needs Arrow_ library installed. + + .. _Arrow: http://crsmithdev.com/arrow/ + + :: + + from datetime import datetime + from sqlalchemy_utils import ArrowType + import arrow + + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + created_at = sa.Column(ArrowType) + + + + article = Article(created_at=arrow.utcnow()) + + + As you may expect all the arrow goodies come available: + + :: + + + article.created_at = article.created_at.replace(hours=-1) + + article.created_at.humanize() + # 'an hour ago' + + """ + impl = types.DateTime + + def __init__(self, *args, **kwargs): + if not arrow: + raise ImproperlyConfigured( + "'arrow' package is required to use 'ArrowType'" + ) + + super(ArrowType, self).__init__(*args, **kwargs) + + def process_bind_param(self, value, dialect): + if value: + return self._coerce(value).to('UTC').naive + return value + + def process_result_value(self, value, dialect): + if value: + return arrow.get(value) + return value + + def _coerce(self, value): + if value is None: + return None + elif isinstance(value, six.string_types): + value = arrow.get(value) + elif isinstance(value, Iterable): + value = arrow.get(*value) + elif isinstance(value, datetime): + value = arrow.get(value) + return value + + @property + def python_type(self): + return self.impl.type.python_type diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/bit.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/bit.py new file mode 100644 index 0000000..73cbc0b --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/bit.py @@ -0,0 +1,22 @@ +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import BIT + + +class BitType(sa.types.TypeDecorator): + """ + BitType offers way of saving BITs into database. + """ + impl = sa.types.BINARY + + def __init__(self, length=1, **kwargs): + self.length = length + sa.types.TypeDecorator.__init__(self, **kwargs) + + def load_dialect_impl(self, dialect): + # Use the native BIT type for drivers that has it. + if dialect.name == 'postgresql': + return dialect.type_descriptor(BIT(self.length)) + elif dialect.name == 'sqlite': + return dialect.type_descriptor(sa.String(self.length)) + else: + return dialect.type_descriptor(type(self.impl)(self.length)) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/choice.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/choice.py new file mode 100644 index 0000000..5a7446c --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/choice.py @@ -0,0 +1,227 @@ +import six +from sqlalchemy import types + +from ..exceptions import ImproperlyConfigured +from .scalar_coercible import ScalarCoercible + +try: + from enum import Enum +except ImportError: + Enum = None + + +class Choice(object): + def __init__(self, code, value): + self.code = code + self.value = value + + def __eq__(self, other): + if isinstance(other, Choice): + return self.code == other.code + return other == self.code + + def __ne__(self, other): + return not (self == other) + + def __unicode__(self): + return six.text_type(self.value) + + def __repr__(self): + return 'Choice(code={code}, value={value})'.format( + code=self.code, + value=self.value + ) + + +class ChoiceType(types.TypeDecorator, ScalarCoercible): + """ + ChoiceType offers way of having fixed set of choices for given column. It + could work with a list of tuple (a collection of key-value pairs), or + integrate with :mod:`enum` in the standard library of Python 3.4+ (the + enum34_ backported package on PyPI is compatible too for ``< 3.4``). + + .. _enum34: https://pypi.python.org/pypi/enum34 + + Columns with ChoiceTypes are automatically coerced to Choice objects while + a list of tuple been passed to the constructor. If a subclass of + :class:`enum.Enum` is passed, columns will be coerced to :class:`enum.Enum` + objects instead. + + :: + + class User(Base): + TYPES = [ + (u'admin', u'Admin'), + (u'regular-user', u'Regular user') + ] + + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + type = sa.Column(ChoiceType(TYPES)) + + + user = User(type=u'admin') + user.type # Choice(type='admin', value=u'Admin') + + Or:: + + import enum + + + class UserType(enum.Enum): + admin = 1 + regular = 2 + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + type = sa.Column(ChoiceType(UserType, impl=sa.Integer())) + + + user = User(type=1) + user.type # + + + ChoiceType is very useful when the rendered values change based on user's + locale: + + :: + + from babel import lazy_gettext as _ + + + class User(Base): + TYPES = [ + (u'admin', _(u'Admin')), + (u'regular-user', _(u'Regular user')) + ] + + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + type = sa.Column(ChoiceType(TYPES)) + + + user = User(type=u'admin') + user.type # Choice(type='admin', value=u'Admin') + + print user.type # u'Admin' + + Or:: + + from enum import Enum + from babel import lazy_gettext as _ + + + class UserType(Enum): + admin = 1 + regular = 2 + + + UserType.admin.label = _(u'Admin') + UserType.regular.label = _(u'Regular user') + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + type = sa.Column(ChoiceType(UserType, impl=sa.Integer())) + + + user = User(type=UserType.admin) + user.type # + + print user.type.label # u'Admin' + """ + + impl = types.Unicode(255) + + def __init__(self, choices, impl=None): + self.choices = choices + + if ( + Enum is not None and + isinstance(choices, type) and + issubclass(choices, Enum) + ): + self.type_impl = EnumTypeImpl(enum_class=choices) + else: + self.type_impl = ChoiceTypeImpl(choices=choices) + + if impl: + self.impl = impl + + @property + def python_type(self): + return self.impl.python_type + + def _coerce(self, value): + return self.type_impl._coerce(value) + + def process_bind_param(self, value, dialect): + return self.type_impl.process_bind_param(value, dialect) + + def process_result_value(self, value, dialect): + return self.type_impl.process_result_value(value, dialect) + + +class ChoiceTypeImpl(object): + """The implementation for the ``Choice`` usage.""" + + def __init__(self, choices): + if not choices: + raise ImproperlyConfigured( + 'ChoiceType needs list of choices defined.' + ) + self.choices_dict = dict(choices) + + def _coerce(self, value): + if value is None: + return value + if isinstance(value, Choice): + return value + return Choice(value, self.choices_dict[value]) + + def process_bind_param(self, value, dialect): + if value and isinstance(value, Choice): + return value.code + return value + + def process_result_value(self, value, dialect): + if value: + return Choice(value, self.choices_dict[value]) + return value + + +class EnumTypeImpl(object): + """The implementation for the ``Enum`` usage.""" + + def __init__(self, enum_class): + if Enum is None: + raise ImproperlyConfigured( + "'enum34' package is required to use 'EnumType' in Python " + "< 3.4" + ) + if not issubclass(enum_class, Enum): + raise ImproperlyConfigured( + "EnumType needs a class of enum defined." + ) + + self.enum_class = enum_class + + def _coerce(self, value): + if value is None: + return None + return self.enum_class(value) + + def process_bind_param(self, value, dialect): + if value is None: + return None + return self.enum_class(value).value + + def process_result_value(self, value, dialect): + return self._coerce(value) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/color.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/color.py new file mode 100644 index 0000000..7020a8f --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/color.py @@ -0,0 +1,80 @@ +import six +from sqlalchemy import types + +from sqlalchemy_utils.exceptions import ImproperlyConfigured + +from .scalar_coercible import ScalarCoercible + +colour = None +try: + import colour + python_colour_type = colour.Color +except ImportError: + python_colour_type = None + + +class ColorType(types.TypeDecorator, ScalarCoercible): + """ + ColorType provides a way for saving Color (from colour_ package) objects + into database. ColorType saves Color objects as strings on the way in and + converts them back to objects when querying the database. + + :: + + + from colour import Color + from sqlalchemy_utils import ColorType + + + class Document(Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(50)) + background_color = sa.Column(ColorType) + + + document = Document() + document.background_color = Color('#F5F5F5') + session.commit() + + + Querying the database returns Color objects: + + :: + + document = session.query(Document).first() + + document.background_color.hex + # '#f5f5f5' + + + .. _colour: https://github.com/vaab/colour + """ + STORE_FORMAT = u'hex' + impl = types.Unicode(20) + python_type = python_colour_type + + def __init__(self, max_length=20, *args, **kwargs): + # Fail if colour is not found. + if colour is None: + raise ImproperlyConfigured( + "'colour' package is required to use 'ColorType'" + ) + + super(ColorType, self).__init__(*args, **kwargs) + self.impl = types.Unicode(max_length) + + def process_bind_param(self, value, dialect): + if value and isinstance(value, colour.Color): + return six.text_type(getattr(value, self.STORE_FORMAT)) + return value + + def process_result_value(self, value, dialect): + if value: + return colour.Color(value) + return value + + def _coerce(self, value): + if value is not None and not isinstance(value, colour.Color): + return colour.Color(value) + return value diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/country.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/country.py new file mode 100644 index 0000000..dd05fa1 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/country.py @@ -0,0 +1,65 @@ +import six +from sqlalchemy import types + +from sqlalchemy_utils.primitives import Country + +from .scalar_coercible import ScalarCoercible + + +class CountryType(types.TypeDecorator, ScalarCoercible): + """ + Changes :class:`.Country` objects to a string representation on the way in + and changes them back to :class:`.Country objects on the way out. + + In order to use CountryType you need to install Babel_ first. + + .. _Babel: http://babel.pocoo.org/ + + :: + + + from sqlalchemy_utils import CountryType, Country + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + country = sa.Column(CountryType) + + + user = User() + user.country = Country('FI') + session.add(user) + session.commit() + + user.country # Country('FI') + user.country.name # Finland + + print user.country # Finland + + + CountryType is scalar coercible:: + + + user.country = 'US' + user.country # Country('US') + """ + impl = types.String(2) + python_type = Country + + def process_bind_param(self, value, dialect): + if isinstance(value, Country): + return value.code + + if isinstance(value, six.string_types): + return value + + def process_result_value(self, value, dialect): + if value is not None: + return Country(value) + + def _coerce(self, value): + if value is not None and not isinstance(value, Country): + return Country(value) + return value diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/currency.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/currency.py new file mode 100644 index 0000000..6c8abbe --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/currency.py @@ -0,0 +1,75 @@ +import six +from sqlalchemy import types + +from sqlalchemy_utils import i18n, ImproperlyConfigured +from sqlalchemy_utils.primitives import Currency + +from .scalar_coercible import ScalarCoercible + + +class CurrencyType(types.TypeDecorator, ScalarCoercible): + """ + Changes :class:`.Currency` objects to a string representation on the way in + and changes them back to :class:`.Currency` objects on the way out. + + In order to use CurrencyType you need to install Babel_ first. + + .. _Babel: http://babel.pocoo.org/ + + :: + + + from sqlalchemy_utils import CurrencyType, Currency + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + currency = sa.Column(CurrencyType) + + + user = User() + user.currency = Currency('USD') + session.add(user) + session.commit() + + user.currency # Currency('USD') + user.currency.name # US Dollar + + str(user.currency) # US Dollar + user.currency.symbol # $ + + + + CurrencyType is scalar coercible:: + + + user.currency = 'US' + user.currency # Currency('US') + """ + impl = types.String(3) + python_type = Currency + + def __init__(self, *args, **kwargs): + if i18n.babel is None: + raise ImproperlyConfigured( + "'babel' package is required in order to use CurrencyType." + ) + + super(CurrencyType, self).__init__(*args, **kwargs) + + def process_bind_param(self, value, dialect): + if isinstance(value, Currency): + return value.code + elif isinstance(value, six.string_types): + return value + + def process_result_value(self, value, dialect): + if value is not None: + return Currency(value) + + def _coerce(self, value): + if value is not None and not isinstance(value, Currency): + return Currency(value) + return value diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/email.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/email.py new file mode 100644 index 0000000..377d92c --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/email.py @@ -0,0 +1,17 @@ +import sqlalchemy as sa + +from ..operators import CaseInsensitiveComparator + + +class EmailType(sa.types.TypeDecorator): + impl = sa.Unicode(255) + comparator_factory = CaseInsensitiveComparator + + def process_bind_param(self, value, dialect): + if value is not None: + return value.lower() + return value + + @property + def python_type(self): + return self.impl.type.python_type diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/encrypted.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/encrypted.py new file mode 100644 index 0000000..4470334 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/encrypted.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- +import base64 +import datetime + +import six +from sqlalchemy.types import Binary, String, TypeDecorator + +from sqlalchemy_utils.exceptions import ImproperlyConfigured + +from .scalar_coercible import ScalarCoercible + +cryptography = None +try: + import cryptography + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.ciphers import( + Cipher, algorithms, modes + ) + from cryptography.fernet import Fernet +except ImportError: + pass + + +class EncryptionDecryptionBaseEngine(object): + """A base encryption and decryption engine. + + This class must be sub-classed in order to create + new engines. + """ + + def _update_key(self, key): + if isinstance(key, six.string_types): + key = key.encode() + digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) + digest.update(key) + engine_key = digest.finalize() + + self._initialize_engine(engine_key) + + def encrypt(self, value): + raise NotImplementedError('Subclasses must implement this!') + + def decrypt(self, value): + raise NotImplementedError('Subclasses must implement this!') + + +class AesEngine(EncryptionDecryptionBaseEngine): + """Provide AES encryption and decryption methods.""" + + BLOCK_SIZE = 16 + PADDING = six.b('*') + + def _initialize_engine(self, parent_class_key): + self.secret_key = parent_class_key + self.iv = self.secret_key[:16] + self.cipher = Cipher( + algorithms.AES(self.secret_key), + modes.CBC(self.iv), + backend=default_backend() + ) + + def _pad(self, value): + """Pad the message to be encrypted, if needed.""" + BS = self.BLOCK_SIZE + P = self.PADDING + padded = (value + (BS - len(value) % BS) * P) + return padded + + def encrypt(self, value): + if not isinstance(value, six.string_types): + value = repr(value) + if isinstance(value, six.text_type): + value = str(value) + value = value.encode() + value = self._pad(value) + encryptor = self.cipher.encryptor() + encrypted = encryptor.update(value) + encryptor.finalize() + encrypted = base64.b64encode(encrypted) + return encrypted + + def decrypt(self, value): + if isinstance(value, six.text_type): + value = str(value) + decryptor = self.cipher.decryptor() + decrypted = base64.b64decode(value) + decrypted = decryptor.update(decrypted)+decryptor.finalize() + decrypted = decrypted.rstrip(self.PADDING) + if not isinstance(decrypted, six.string_types): + decrypted = decrypted.decode('utf-8') + return decrypted + + +class FernetEngine(EncryptionDecryptionBaseEngine): + """Provide Fernet encryption and decryption methods.""" + + def _initialize_engine(self, parent_class_key): + self.secret_key = base64.urlsafe_b64encode(parent_class_key) + self.fernet = Fernet(self.secret_key) + + def encrypt(self, value): + if not isinstance(value, six.string_types): + value = repr(value) + if isinstance(value, six.text_type): + value = str(value) + value = value.encode() + encrypted = self.fernet.encrypt(value) + return encrypted + + def decrypt(self, value): + if isinstance(value, six.text_type): + value = str(value) + decrypted = self.fernet.decrypt(value) + if not isinstance(decrypted, six.string_types): + decrypted = decrypted.decode('utf-8') + return decrypted + + +class EncryptedType(TypeDecorator, ScalarCoercible): + """ + EncryptedType provides a way to encrypt and decrypt values, + to and from databases, that their type is a basic SQLAlchemy type. + For example Unicode, String or even Boolean. + On the way in, the value is encrypted and on the way out the stored value + is decrypted. + + EncryptedType needs Cryptography_ library in order to work. + A simple example is given below. + + .. _Cryptography: https://cryptography.io/en/latest/ + + :: + + import sqlalchemy as sa + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + from sqlalchemy_utils import EncryptedType + + + secret_key = 'secretkey1234' + # setup + engine = create_engine('sqlite:///:memory:') + connection = engine.connect() + Base = declarative_base() + + class User(Base): + __tablename__ = "user" + id = sa.Column(sa.Integer, primary_key=True) + username = sa.Column(EncryptedType(sa.Unicode, secret_key)) + access_token = sa.Column(EncryptedType(sa.String, secret_key)) + is_active = sa.Column(EncryptedType(sa.Boolean, secret_key)) + number_of_accounts = sa.Column(EncryptedType(sa.Integer, + secret_key)) + + sa.orm.configure_mappers() + Base.metadata.create_all(connection) + + # create a configured "Session" class + Session = sessionmaker(bind=connection) + + # create a Session + session = Session() + + # example + user_name = u'secret_user' + test_token = 'atesttoken' + active = True + num_of_accounts = 2 + + user = User(username=user_name, access_token=test_token, + is_active=active, accounts_num=accounts) + session.add(user) + session.commit() + + print('id: {}'.format(user.id)) + print('username: {}'.format(user.username)) + print('token: {}'.format(user.access_token)) + print('active: {}'.format(user.is_active)) + print('accounts: {}'.format(user.accounts_num)) + + # teardown + session.close_all() + Base.metadata.drop_all(connection) + connection.close() + engine.dispose() + + The key parameter accepts a callable to allow for the key to change + per-row instead of be fixed for the whole table. + + :: + def get_key(): + return 'dynamic-key' + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + username = sa.Column(EncryptedType( + sa.Unicode, get_key)) + + """ + + impl = Binary + + def __init__(self, type_in=None, key=None, engine=None, **kwargs): + """Initialization.""" + if not cryptography: + raise ImproperlyConfigured( + "'cryptography' is required to use EncryptedType" + ) + super(EncryptedType, self).__init__(**kwargs) + # set the underlying type + if type_in is None: + type_in = String() + elif isinstance(type_in, type): + type_in = type_in() + self.underlying_type = type_in + self._key = key + if not engine: + engine = AesEngine + self.engine = engine() + + @property + def key(self): + return self._key + + @key.setter + def key(self, value): + self._key = value + + def _update_key(self): + key = self._key() if callable(self._key) else self._key + self.engine._update_key(key) + + def process_bind_param(self, value, dialect): + """Encrypt a value on the way in.""" + if value is not None: + self._update_key() + + try: + value = self.underlying_type.process_bind_param( + value, dialect + ) + + except AttributeError: + # Doesn't have 'process_bind_param' + + # Handle 'boolean' and 'dates' + type_ = self.underlying_type.python_type + if issubclass(type_, bool): + value = 'true' if value else 'false' + + elif issubclass(type_, (datetime.date, datetime.time)): + value = value.isoformat() + + return self.engine.encrypt(value) + + def process_result_value(self, value, dialect): + """Decrypt value on the way out.""" + if value is not None: + self._update_key() + decrypted_value = self.engine.decrypt(value) + + try: + return self.underlying_type.process_result_value( + decrypted_value, dialect + ) + + except AttributeError: + # Doesn't have 'process_result_value' + + # Handle 'boolean' and 'dates' + type_ = self.underlying_type.python_type + if issubclass(type_, bool): + return decrypted_value == 'true' + + elif issubclass(type_, datetime.datetime): + return datetime.datetime.strptime( + decrypted_value, '%Y-%m-%dT%H:%M:%S' + ) + + elif issubclass(type_, datetime.time): + return datetime.datetime.strptime( + decrypted_value, '%H:%M:%S' + ).time() + + elif issubclass(type_, datetime.date): + return datetime.datetime.strptime( + decrypted_value, '%Y-%m-%d' + ).date() + + # Handle all others + return self.underlying_type.python_type(decrypted_value) + + def _coerce(self, value): + if isinstance(self.underlying_type, ScalarCoercible): + return self.underlying_type._coerce(value) + + return value diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/ip_address.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/ip_address.py new file mode 100644 index 0000000..7ec9741 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/ip_address.py @@ -0,0 +1,73 @@ +import six +from sqlalchemy import types + +from sqlalchemy_utils.exceptions import ImproperlyConfigured + +from .scalar_coercible import ScalarCoercible + +ip_address = None +try: + from ipaddress import ip_address +except ImportError: + try: + from ipaddr import IPAddress as ip_address + except ImportError: + pass + + +class IPAddressType(types.TypeDecorator, ScalarCoercible): + """ + Changes IPAddress objects to a string representation on the way in and + changes them back to IPAddress objects on the way out. + + IPAddressType uses ipaddress package on Python >= 3 and ipaddr_ package on + Python 2. In order to use IPAddressType with python you need to install + ipaddr_ first. + + .. _ipaddr: https://pypi.python.org/pypi/ipaddr + + :: + + + from sqlalchemy_utils import IPAddressType + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + ip_address = sa.Column(IPAddressType) + + + user = User() + user.ip_address = '123.123.123.123' + session.add(user) + session.commit() + + user.ip_address # IPAddress object + """ + + impl = types.Unicode(50) + + def __init__(self, max_length=50, *args, **kwargs): + if not ip_address: + raise ImproperlyConfigured( + "'ipaddr' package is required to use 'IPAddressType' " + "in python 2" + ) + + super(IPAddressType, self).__init__(*args, **kwargs) + self.impl = types.Unicode(max_length) + + def process_bind_param(self, value, dialect): + return six.text_type(value) if value else None + + def process_result_value(self, value, dialect): + return ip_address(value) if value else None + + def _coerce(self, value): + return ip_address(value) if value else None + + @property + def python_type(self): + return self.impl.type.python_type diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/json.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/json.py new file mode 100644 index 0000000..06b3ca5 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/json.py @@ -0,0 +1,88 @@ +from __future__ import absolute_import + +import six +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql.base import ischema_names + +from ..exceptions import ImproperlyConfigured + +json = None +try: + import anyjson as json +except ImportError: + import json as json + +try: + from sqlalchemy.dialects.postgresql import JSON + has_postgres_json = True +except ImportError: + class PostgresJSONType(sa.types.UserDefinedType): + """ + Text search vector type for postgresql. + """ + def get_col_spec(self): + return 'json' + + ischema_names['json'] = PostgresJSONType + has_postgres_json = False + + +class JSONType(sa.types.TypeDecorator): + """ + JSONType offers way of saving JSON data structures to database. On + PostgreSQL the underlying implementation of this data type is 'json' while + on other databases its simply 'text'. + + :: + + + from sqlalchemy_utils import JSONType + + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(50)) + details = sa.Column(JSONType) + + + product = Product() + product.details = { + 'color': 'red', + 'type': 'car', + 'max-speed': '400 mph' + } + session.commit() + """ + impl = sa.UnicodeText + + def __init__(self, *args, **kwargs): + if json is None: + raise ImproperlyConfigured( + 'JSONType needs anyjson package installed.' + ) + super(JSONType, self).__init__(*args, **kwargs) + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + # Use the native JSON type. + if has_postgres_json: + return dialect.type_descriptor(JSON()) + else: + return dialect.type_descriptor(PostgresJSONType()) + else: + return dialect.type_descriptor(self.impl) + + def process_bind_param(self, value, dialect): + if dialect.name == 'postgresql' and has_postgres_json: + return value + if value is not None: + value = six.text_type(json.dumps(value)) + return value + + def process_result_value(self, value, dialect): + if dialect.name == 'postgresql': + return value + if value is not None: + value = json.loads(value) + return value diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/locale.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/locale.py new file mode 100644 index 0000000..fccc7d2 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/locale.py @@ -0,0 +1,75 @@ +import six +from sqlalchemy import types + +from ..exceptions import ImproperlyConfigured +from .scalar_coercible import ScalarCoercible + +babel = None +try: + import babel +except ImportError: + pass + + +class LocaleType(types.TypeDecorator, ScalarCoercible): + """ + LocaleType saves Babel_ Locale objects into database. The Locale objects + are converted to string on the way in and back to object on the way out. + + In order to use LocaleType you need to install Babel_ first. + + .. _Babel: http://babel.pocoo.org/ + + :: + + + from sqlalchemy_utils import LocaleType + from babel import Locale + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(50)) + locale = sa.Column(LocaleType) + + + user = User() + user.locale = Locale('en_US') + session.add(user) + session.commit() + + + Like many other types this type also supports scalar coercion: + + :: + + + user.locale = 'de_DE' + user.locale # Locale('de_DE') + + """ + + impl = types.Unicode(10) + + def __init__(self): + if babel is None: + raise ImproperlyConfigured( + 'Babel packaged is required with LocaleType.' + ) + + def process_bind_param(self, value, dialect): + if isinstance(value, babel.Locale): + return six.text_type(value) + + if isinstance(value, six.string_types): + return value + + def process_result_value(self, value, dialect): + if value is not None: + return babel.Locale(value) + + def _coerce(self, value): + if value is not None and not isinstance(value, babel.Locale): + return babel.Locale(value) + return value diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/password.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/password.py new file mode 100644 index 0000000..2b45bf4 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/password.py @@ -0,0 +1,216 @@ +import weakref + +import six +from sqlalchemy import types +from sqlalchemy.dialects import oracle, postgresql +from sqlalchemy.ext.mutable import Mutable + +from sqlalchemy_utils.exceptions import ImproperlyConfigured + +from .scalar_coercible import ScalarCoercible + +passlib = None +try: + import passlib + from passlib.context import CryptContext +except ImportError: + pass + + +class Password(Mutable, object): + + @classmethod + def coerce(cls, key, value): + if isinstance(value, Password): + return value + + if isinstance(value, (six.string_types, six.binary_type)): + return cls(value, secret=True) + + super(Password, cls).coerce(key, value) + + def __init__(self, value, context=None, secret=False): + # Store the hash (if it is one). + self.hash = value if not secret else None + + # Store the secret if we have one. + self.secret = value if secret else None + + # The hash should be bytes. + if isinstance(self.hash, six.text_type): + self.hash = self.hash.encode('utf8') + + # Save weakref of the password context (if we have one) + self.context = weakref.proxy(context) if context is not None else None + + def __eq__(self, value): + if self.hash is None or value is None: + # Ensure that we don't continue comparison if one of us is None. + return self.hash is value + + if isinstance(value, Password): + # Comparing 2 hashes isn't very useful; but this equality + # method breaks otherwise. + return value.hash == self.hash + + if self.context is None: + # Compare 2 hashes again as we don't know how to validate. + return value == self + + if isinstance(value, (six.string_types, six.binary_type)): + valid, new = self.context.verify_and_update(value, self.hash) + if valid and new: + # New hash was calculated due to various reasons; stored one + # wasn't optimal, etc. + self.hash = new + + # The hash should be bytes. + if isinstance(self.hash, six.string_types): + self.hash = self.hash.encode('utf8') + self.changed() + + return valid + + return False + + def __ne__(self, value): + return not (self == value) + + +class PasswordType(types.TypeDecorator, ScalarCoercible): + """ + PasswordType hashes passwords as they come into the database and allows + verifying them using a pythonic interface. + + All keyword arguments (aside from max_length) are forwarded to the + construction of a `passlib.context.CryptContext` object. + + The following usage will create a password column that will + automatically hash new passwords as `pbkdf2_sha512` but still compare + passwords against pre-existing `md5_crypt` hashes. As passwords are + compared; the password hash in the database will be updated to + be `pbkdf2_sha512`. + + :: + + + class Model(Base): + password = sa.Column(PasswordType( + schemes=[ + 'pbkdf2_sha512', + 'md5_crypt' + ], + + deprecated=['md5_crypt'] + )) + + + Verifying password is as easy as: + + :: + + target = Model() + target.password = 'b' + # '$5$rounds=80000$H.............' + + target.password == 'b' + # True + + """ + + impl = types.VARBINARY(1024) + python_type = Password + + def __init__(self, max_length=None, **kwargs): + # Fail if passlib is not found. + if passlib is None: + raise ImproperlyConfigured( + "'passlib' is required to use 'PasswordType'" + ) + + # Construct the passlib crypt context. + self.context = CryptContext(**kwargs) + + if max_length is None: + max_length = self.calculate_max_length() + + # Set the length to the now-calculated max length. + self.length = max_length + + def calculate_max_length(self): + # Calculate the largest possible encoded password. + # name + rounds + salt + hash + ($ * 4) of largest hash + max_lengths = [1024] + for name in self.context.schemes(): + scheme = getattr(__import__('passlib.hash').hash, name) + length = 4 + len(scheme.name) + length += len(str(getattr(scheme, 'max_rounds', ''))) + length += (getattr(scheme, 'max_salt_size', 0) or 0) + length += getattr( + scheme, + 'encoded_checksum_size', + scheme.checksum_size + ) + max_lengths.append(length) + + # Return the maximum calculated max length. + return max(max_lengths) + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + # Use a BYTEA type for postgresql. + impl = postgresql.BYTEA(self.length) + return dialect.type_descriptor(impl) + if dialect.name == 'oracle': + # Use a RAW type for oracle. + impl = oracle.RAW(self.length) + return dialect.type_descriptor(impl) + + # Use a VARBINARY for all other dialects. + impl = types.VARBINARY(self.length) + return dialect.type_descriptor(impl) + + def process_bind_param(self, value, dialect): + if isinstance(value, Password): + # If were given a password secret; encrypt it. + if value.secret is not None: + return self.context.encrypt(value.secret).encode('utf8') + + # Value has already been hashed. + return value.hash + + if isinstance(value, six.string_types): + # Assume value has not been hashed. + return self.context.encrypt(value).encode('utf8') + + def process_result_value(self, value, dialect): + if value is not None: + return Password(value, self.context) + + def _coerce(self, value): + + if value is None: + return + + if not isinstance(value, Password): + # Hash the password using the default scheme. + value = self.context.encrypt(value).encode('utf8') + return Password(value, context=self.context) + + else: + # If were given a password object; ensure the context is right. + value.context = weakref.proxy(self.context) + + # If were given a password secret; encrypt it. + if value.secret is not None: + value.hash = self.context.encrypt(value.secret).encode('utf8') + value.secret = None + + return value + + @property + def python_type(self): + return self.impl.type.python_type + + +Password.associate_with(PasswordType) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/pg_composite.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/pg_composite.py new file mode 100644 index 0000000..255e0d8 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/pg_composite.py @@ -0,0 +1,352 @@ +""" +CompositeType provides means to interact with +`PostgreSQL composite types`_. Currently this type features: + +* Easy attribute access to composite type fields +* Supports SQLAlchemy TypeDecorator types +* Ability to include composite types as part of PostgreSQL arrays +* Type creation and dropping + +Installation +^^^^^^^^^^^^ + +CompositeType automatically attaches `before_create` and `after_drop` DDL +listeners. These listeners create and drop the composite type in the +database. This means it works out of the box in your test environment where +you create the tables on each test run. + +When you already have your database set up you should call +:func:`register_composites` after you've set up all models. + +:: + + register_composites(conn) + + + +Usage +^^^^^ + +:: + + from collections import OrderedDict + + import sqlalchemy as sa + from sqlalchemy_utils import Composite, CurrencyType + + + class Account(Base): + __tablename__ = 'account' + id = sa.Column(sa.Integer, primary_key=True) + balance = sa.Column( + CompositeType( + 'money_type', + [ + sa.Column('currency', CurrencyType), + sa.Column('amount', sa.Integer) + ] + ) + ) + + +Accessing fields +^^^^^^^^^^^^^^^^ + +CompositeType provides attribute access to underlying fields. In the following +example we find all accounts with balance amount more than 5000. + + +:: + + session.query(Account).filter(Account.balance.amount > 5000) + + +Arrays of composites +^^^^^^^^^^^^^^^^^^^^ + +:: + + from sqlalchemy_utils import CompositeArray + + + class Account(Base): + __tablename__ = 'account' + id = sa.Column(sa.Integer, primary_key=True) + balances = sa.Column( + CompositeArray( + CompositeType( + 'money_type', + [ + sa.Column('currency', CurrencyType), + sa.Column('amount', sa.Integer) + ] + ) + ) + ) + + +.. _PostgreSQL composite types: + http://www.postgresql.org/docs/devel/static/rowtypes.html + + +Related links: + +http://schinckel.net/2014/09/24/using-postgres-composite-types-in-django/ +""" +from collections import namedtuple + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.schema import _CreateDropBase +from sqlalchemy.sql.expression import FunctionElement +from sqlalchemy.types import ( + SchemaType, + to_instance, + TypeDecorator, + UserDefinedType +) + +from sqlalchemy_utils import ImproperlyConfigured + +psycopg2 = None +CompositeCaster = None +adapt = None +AsIs = None +register_adapter = None +try: + import psycopg2 + from psycopg2.extras import CompositeCaster + from psycopg2.extensions import adapt, AsIs, register_adapter +except ImportError: + pass + + +class CompositeElement(FunctionElement): + """ + Instances of this class wrap a Postgres composite type. + """ + def __init__(self, base, field, type_): + self.name = field + self.type = to_instance(type_) + + super(CompositeElement, self).__init__(base) + + +@compiles(CompositeElement) +def _compile_pgelem(expr, compiler, **kw): + return '(%s).%s' % (compiler.process(expr.clauses, **kw), expr.name) + + +class CompositeArray(ARRAY): + def _proc_array(self, arr, itemproc, dim, collection): + if dim is None: + if isinstance(self.item_type, CompositeType): + arr = [itemproc(a) for a in arr] + return arr + return ARRAY._proc_array(self, arr, itemproc, dim, collection) + + +# TODO: Make the registration work on connection level instead of global level +registered_composites = {} + + +class CompositeType(UserDefinedType, SchemaType): + """ + Represents a PostgreSQL composite type. + + :param name: + Name of the composite type. + :param columns: + List of columns that this composite type consists of + """ + python_type = tuple + + class comparator_factory(UserDefinedType.Comparator): + def __getattr__(self, key): + try: + type_ = self.type.typemap[key] + except KeyError: + raise KeyError( + "Type '%s' doesn't have an attribute: '%s'" % ( + self.name, key + ) + ) + + return CompositeElement(self.expr, key, type_) + + def __init__(self, name, columns): + if psycopg2 is None: + raise ImproperlyConfigured( + "'psycopg2' package is required in order to use CompositeType." + ) + SchemaType.__init__(self) + self.name = name + self.columns = columns + if name in registered_composites: + self.type_cls = registered_composites[name].type_cls + else: + self.type_cls = namedtuple( + self.name, [c.name for c in columns] + ) + registered_composites[name] = self + + class Caster(CompositeCaster): + def make(obj, values): + return self.type_cls(*values) + + self.caster = Caster + attach_composite_listeners() + + def get_col_spec(self): + return self.name + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + processed_value = [] + for i, column in enumerate(self.columns): + if isinstance(column.type, TypeDecorator): + processed_value.append( + column.type.process_bind_param( + value[i], dialect + ) + ) + else: + processed_value.append(value[i]) + return self.type_cls(*processed_value) + return process + + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return None + cls = value.__class__ + kwargs = {} + for column in self.columns: + if isinstance(column.type, TypeDecorator): + kwargs[column.name] = column.type.process_result_value( + getattr(value, column.name), dialect + ) + else: + kwargs[column.name] = getattr(value, column.name) + return cls(**kwargs) + return process + + def create(self, bind=None, checkfirst=None): + if ( + not checkfirst or + not bind.dialect.has_type(bind, self.name, schema=self.schema) + ): + bind.execute(CreateCompositeType(self)) + + def drop(self, bind=None, checkfirst=True): + if ( + checkfirst and + bind.dialect.has_type(bind, self.name, schema=self.schema) + ): + bind.execute(DropCompositeType(self)) + + +def register_psycopg2_composite(dbapi_connection, composite): + psycopg2.extras.register_composite( + composite.name, + dbapi_connection, + globally=True, + factory=composite.caster + ) + + def adapt_composite(value): + values = [ + adapt( + getattr(value, column.name) + if not isinstance(column.type, TypeDecorator) + else column.type.process_bind_param( + getattr(value, column.name), + PGDialect_psycopg2() + ) + ).getquoted().decode('utf-8') + for column in + composite.columns + ] + return AsIs("(%s)::%s" % (', '.join(values), composite.name)) + + register_adapter(composite.type_cls, adapt_composite) + + +def before_create(target, connection, **kw): + for name, composite in registered_composites.items(): + composite.create(connection, checkfirst=True) + register_psycopg2_composite( + connection.connection.connection, + composite + ) + + +def after_drop(target, connection, **kw): + for name, composite in registered_composites.items(): + composite.drop(connection, checkfirst=True) + + +def register_composites(connection): + for name, composite in registered_composites.items(): + register_psycopg2_composite( + connection.connection.connection, + composite + ) + + +def attach_composite_listeners(): + listeners = [ + (sa.MetaData, 'before_create', before_create), + (sa.MetaData, 'after_drop', after_drop), + ] + for listener in listeners: + if not sa.event.contains(*listener): + sa.event.listen(*listener) + + +def remove_composite_listeners(): + listeners = [ + (sa.MetaData, 'before_create', before_create), + (sa.MetaData, 'after_drop', after_drop), + ] + for listener in listeners: + if sa.event.contains(*listener): + sa.event.remove(*listener) + + +class CreateCompositeType(_CreateDropBase): + pass + + +@compiles(CreateCompositeType) +def _visit_create_composite_type(create, compiler, **kw): + type_ = create.element + fields = ', '.join( + '{name} {type}'.format( + name=column.name, + type=compiler.dialect.type_compiler.process( + to_instance(column.type) + ) + ) + for column in type_.columns + ) + + return 'CREATE TYPE {name} AS ({fields})'.format( + name=compiler.preparer.format_type(type_), + fields=fields + ) + + +class DropCompositeType(_CreateDropBase): + pass + + +@compiles(DropCompositeType) +def _visit_drop_composite_type(drop, compiler, **kw): + type_ = drop.element + + return 'DROP TYPE {name}'.format(name=compiler.preparer.format_type(type_)) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/phone_number.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/phone_number.py new file mode 100644 index 0000000..07490c0 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/phone_number.py @@ -0,0 +1,131 @@ +from sqlalchemy import types + +from sqlalchemy_utils.exceptions import ImproperlyConfigured +from sqlalchemy_utils.utils import str_coercible + +from .scalar_coercible import ScalarCoercible + +try: + import phonenumbers + from phonenumbers.phonenumber import PhoneNumber as BasePhoneNumber +except ImportError: + phonenumbers = None + BasePhoneNumber = object + + +@str_coercible +class PhoneNumber(BasePhoneNumber): + ''' + Extends a PhoneNumber class from `Python phonenumbers library`_. Adds + different phone number formats to attributes, so they can be easily used + in templates. Phone number validation method is also implemented. + + Takes the raw phone number and country code as params and parses them + into a PhoneNumber object. + + .. _Python phonenumbers library: + https://github.com/daviddrysdale/python-phonenumbers + + :param raw_number: + String representation of the phone number. + :param country_code: + Country code of the phone number. + ''' + def __init__(self, raw_number, country_code=None): + # Bail if phonenumbers is not found. + if phonenumbers is None: + raise ImproperlyConfigured( + "'phonenumbers' is required to use 'PhoneNumber'") + + self._phone_number = phonenumbers.parse(raw_number, country_code) + super(PhoneNumber, self).__init__( + country_code=self._phone_number.country_code, + national_number=self._phone_number.national_number, + extension=self._phone_number.extension, + italian_leading_zero=self._phone_number.italian_leading_zero, + raw_input=self._phone_number.raw_input, + country_code_source=self._phone_number.country_code_source, + preferred_domestic_carrier_code=( + self._phone_number.preferred_domestic_carrier_code + ) + ) + self.national = phonenumbers.format_number( + self._phone_number, + phonenumbers.PhoneNumberFormat.NATIONAL + ) + self.international = phonenumbers.format_number( + self._phone_number, + phonenumbers.PhoneNumberFormat.INTERNATIONAL + ) + self.e164 = phonenumbers.format_number( + self._phone_number, + phonenumbers.PhoneNumberFormat.E164 + ) + + def is_valid_number(self): + return phonenumbers.is_valid_number(self._phone_number) + + def __unicode__(self): + return self.national + + +class PhoneNumberType(types.TypeDecorator, ScalarCoercible): + """ + Changes PhoneNumber objects to a string representation on the way in and + changes them back to PhoneNumber objects on the way out. If E164 is used + as storing format, no country code is needed for parsing the database + value to PhoneNumber object. + + :: + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + phone_number = sa.Column(PhoneNumberType()) + + + user = User(phone_number='+358401234567') + + user.phone_number.e164 # u'+358401234567' + user.phone_number.international # u'+358 40 1234567' + user.phone_number.national # u'040 1234567' + """ + STORE_FORMAT = 'e164' + impl = types.Unicode(20) + + def python_type(self, text): + return self._coerce(text) + + def __init__(self, country_code='US', max_length=20, *args, **kwargs): + # Bail if phonenumbers is not found. + if phonenumbers is None: + raise ImproperlyConfigured( + "'phonenumbers' is required to use 'PhoneNumberType'") + + super(PhoneNumberType, self).__init__(*args, **kwargs) + self.country_code = country_code + self.impl = types.Unicode(max_length) + + def process_bind_param(self, value, dialect): + if value: + if not isinstance(value, PhoneNumber): + value = PhoneNumber(value, country_code=self.country_code) + + if self.STORE_FORMAT == 'e164' and value.extension: + return '%s;ext=%s' % (value.e164, value.extension) + + return getattr(value, self.STORE_FORMAT) + + return value + + def process_result_value(self, value, dialect): + if value: + return PhoneNumber(value, self.country_code) + return value + + def _coerce(self, value): + if value and not isinstance(value, PhoneNumber): + value = PhoneNumber(value, country_code=self.country_code) + + return value or None diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/range.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/range.py new file mode 100644 index 0000000..dc40653 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/range.py @@ -0,0 +1,418 @@ +""" +SQLAlchemy-Utils provides wide variety of range data types. All range data +types return Interval objects of intervals_ package. In order to use range data +types you need to install intervals_ with: + +:: + + pip install intervals + + +Intervals package provides good chunk of additional interval operators that for +example psycopg2 range objects do not support. + + + +Some good reading for practical interval implementations: + +http://wiki.postgresql.org/images/f/f0/Range-types.pdf + + +Range type initialization +------------------------- + +:: + + + + from sqlalchemy_utils import IntRangeType + + + class Event(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + estimated_number_of_persons = sa.Column(IntRangeType) + + + +You can also set a step parameter for range type. The values that are not +multipliers of given step will be rounded up to nearest step multiplier. + + +:: + + + from sqlalchemy_utils import IntRangeType + + + class Event(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + estimated_number_of_persons = sa.Column(IntRangeType(step=1000)) + + + event = Event(estimated_number_of_persons=[100, 1200]) + event.estimated_number_of_persons.lower # 0 + event.estimated_number_of_persons.upper # 1000 + + +Range type operators +-------------------- + +SQLAlchemy-Utils supports many range type operators. These operators follow the +`intervals` package interval coercion rules. + +So for example when we make a query such as: + +:: + + session.query(Car).filter(Car.price_range == 300) + + +It is essentially the same as: + +:: + + session.query(Car).filter(Car.price_range == DecimalInterval([300, 300])) + + +Comparison operators +^^^^^^^^^^^^^^^^^^^^ + +All range types support all comparison operators (>, >=, ==, !=, <=, <). + +:: + + Car.price_range < [12, 300] + + Car.price_range == [12, 300] + + Car.price_range < 300 + + Car.price_range > (300, 500) + + # Whether or not range is strictly left of another range + Car.price_range << [300, 500] + + # Whether or not range is strictly right of another range + Car.price_range >> [300, 500] + + + +Membership operators +^^^^^^^^^^^^^^^^^^^^ + +:: + + Car.price_range.contains([300, 500]) + + Car.price_range.contained_by([300, 500]) + + Car.price_range.in_([[300, 500], [800, 900]]) + + ~ Car.price_range.in_([[300, 400], [700, 800]]) + + +Length +^^^^^^ + +SQLAlchemy-Utils provides length property for all range types. The +implementation of this property varies on different range types. + +In the following example we find all cars whose price range's length is more +than 500. + +:: + + session.query(Car).filter( + Car.price_range.length > 500 + ) + + + +.. _intervals: https://github.com/kvesteri/intervals +""" +from collections import Iterable +from datetime import timedelta + +import six +import sqlalchemy as sa +from sqlalchemy import types +from sqlalchemy.dialects.postgresql import ( + DATERANGE, + INT4RANGE, + NUMRANGE, + TSRANGE +) + +from ..exceptions import ImproperlyConfigured +from .scalar_coercible import ScalarCoercible + +intervals = None +try: + import intervals +except ImportError: + pass + + +class RangeComparator(types.TypeEngine.Comparator): + @classmethod + def coerced_func(cls, func): + def operation(self, other, **kwargs): + other = self.coerce_arg(other) + return getattr(types.TypeEngine.Comparator, func)( + self, other, **kwargs + ) + return operation + + def coerce_arg(self, other): + coerced_types = ( + self.type.interval_class.type, + tuple, + list, + ) + six.string_types + + if isinstance(other, coerced_types): + return self.type.interval_class(other) + return other + + def in_(self, other): + if ( + isinstance(other, Iterable) and + not isinstance(other, six.string_types) + ): + other = map(self.coerce_arg, other) + return super(RangeComparator, self).in_(other) + + def notin_(self, other): + if ( + isinstance(other, Iterable) and + not isinstance(other, six.string_types) + ): + other = map(self.coerce_arg, other) + return super(RangeComparator, self).notin_(other) + + def __rshift__(self, other, **kwargs): + """ + Returns whether or not given interval is strictly right of another + interval. + + [a, b] >> [c, d] True, if a > d + """ + other = self.coerce_arg(other) + return self.op('>>')(other) + + def __lshift__(self, other, **kwargs): + """ + Returns whether or not given interval is strictly left of another + interval. + + [a, b] << [c, d] True, if b < c + """ + other = self.coerce_arg(other) + return self.op('<<')(other) + + def contains(self, other, **kwargs): + other = self.coerce_arg(other) + return self.op('@>')(other) + + def contained_by(self, other, **kwargs): + other = self.coerce_arg(other) + return self.op('<@')(other) + + +class DiscreteRangeComparator(RangeComparator): + @property + def length(self): + return sa.func.upper(self.expr) - self.step - sa.func.lower(self.expr) + + +class IntRangeComparator(DiscreteRangeComparator): + step = 1 + + +class DateRangeComparator(DiscreteRangeComparator): + step = timedelta(days=1) + + +class ContinuousRangeComparator(RangeComparator): + @property + def length(self): + return sa.func.upper(self.expr) - sa.func.lower(self.expr) + + +funcs = [ + '__eq__', + '__ne__', + '__lt__', + '__le__', + '__gt__', + '__ge__', +] + + +for func in funcs: + setattr( + RangeComparator, + func, + RangeComparator.coerced_func(func) + ) + + +class RangeType(types.TypeDecorator, ScalarCoercible): + comparator_factory = RangeComparator + + def __init__(self, *args, **kwargs): + if intervals is None: + raise ImproperlyConfigured( + 'RangeType needs intervals package installed.' + ) + self.step = kwargs.pop('step', None) + super(RangeType, self).__init__(*args, **kwargs) + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + # Use the native range type for postgres. + return dialect.type_descriptor(self.impl) + else: + # Other drivers don't have native types. + return dialect.type_descriptor(sa.String(255)) + + def process_bind_param(self, value, dialect): + if value is not None: + return str(value) + return value + + def process_result_value(self, value, dialect): + if value is not None: + if self.interval_class.step is not None: + return self.canonicalize_result_value( + self.interval_class(value, step=self.step) + ) + else: + return self.interval_class(value, step=self.step) + return value + + def canonicalize_result_value(self, value): + return intervals.canonicalize(value, True, True) + + def _coerce(self, value): + if value is None: + return None + return self.interval_class(value, step=self.step) + + +class IntRangeType(RangeType): + """ + IntRangeType provides way for saving ranges of integers into database. On + PostgreSQL this type maps to native INT4RANGE type while on other drivers + this maps to simple string column. + + Example:: + + + from sqlalchemy_utils import IntRangeType + + + class Event(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255)) + estimated_number_of_persons = sa.Column(IntRangeType) + + + party = Event(name=u'party') + + # we estimate the party to contain minium of 10 persons and at max + # 100 persons + party.estimated_number_of_persons = [10, 100] + + print party.estimated_number_of_persons + # '10-100' + + + IntRangeType returns the values as IntInterval objects. These objects + support many arithmetic operators:: + + + meeting = Event(name=u'meeting') + + meeting.estimated_number_of_persons = [20, 40] + + total = ( + meeting.estimated_number_of_persons + + party.estimated_number_of_persons + ) + print total + # '30-140' + """ + impl = INT4RANGE + comparator_factory = IntRangeComparator + + def __init__(self, *args, **kwargs): + super(IntRangeType, self).__init__(*args, **kwargs) + self.interval_class = intervals.IntInterval + + +class DateRangeType(RangeType): + """ + DateRangeType provides way for saving ranges of dates into database. On + PostgreSQL this type maps to native DATERANGE type while on other drivers + this maps to simple string column. + + Example:: + + + from sqlalchemy_utils import DateRangeType + + + class Reservation(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + room_id = sa.Column(sa.Integer)) + during = sa.Column(DateRangeType) + """ + impl = DATERANGE + comparator_factory = DateRangeComparator + + def __init__(self, *args, **kwargs): + super(DateRangeType, self).__init__(*args, **kwargs) + self.interval_class = intervals.DateInterval + + +class NumericRangeType(RangeType): + """ + NumericRangeType provides way for saving ranges of decimals into database. + On PostgreSQL this type maps to native NUMRANGE type while on other drivers + this maps to simple string column. + + Example:: + + + from sqlalchemy_utils import NumericRangeType + + + class Car(Base): + __tablename__ = 'car' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(255))) + price_range = sa.Column(NumericRangeType) + """ + + impl = NUMRANGE + comparator_factory = ContinuousRangeComparator + + def __init__(self, *args, **kwargs): + super(NumericRangeType, self).__init__(*args, **kwargs) + self.interval_class = intervals.DecimalInterval + + +class DateTimeRangeType(RangeType): + impl = TSRANGE + comparator_factory = ContinuousRangeComparator + + def __init__(self, *args, **kwargs): + super(DateTimeRangeType, self).__init__(*args, **kwargs) + self.interval_class = intervals.DateTimeInterval diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/scalar_coercible.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/scalar_coercible.py new file mode 100644 index 0000000..ec436cc --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/scalar_coercible.py @@ -0,0 +1,6 @@ +class ScalarCoercible(object): + def _coerce(self, value): + raise NotImplemented + + def coercion_listener(self, target, value, oldvalue, initiator): + return self._coerce(value) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/scalar_list.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/scalar_list.py new file mode 100644 index 0000000..4a79b59 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/scalar_list.py @@ -0,0 +1,83 @@ +import six +import sqlalchemy as sa +from sqlalchemy import types + + +class ScalarListException(Exception): + pass + + +class ScalarListType(types.TypeDecorator): + """ + ScalarListType type provides convenient way for saving multiple scalar + values in one column. ScalarListType works like list on python side and + saves the result as comma-separated list in the database (custom separators + can also be used). + + Example :: + + + from sqlalchemy_utils import ScalarListType + + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True) + hobbies = sa.Column(ScalarListType()) + + + user = User() + user.hobbies = [u'football', u'ice_hockey'] + session.commit() + + + You can easily set up integer lists too: + + :: + + + from sqlalchemy_utils import ScalarListType + + + class Player(Base): + __tablename__ = 'player' + id = sa.Column(sa.Integer, autoincrement=True) + points = sa.Column(ScalarListType(int)) + + + player = Player() + player.points = [11, 12, 8, 80] + session.commit() + + + """ + + impl = sa.UnicodeText() + + def __init__(self, coerce_func=six.text_type, separator=u','): + self.separator = six.text_type(separator) + self.coerce_func = coerce_func + + def process_bind_param(self, value, dialect): + # Convert list of values to unicode separator-separated list + # Example: [1, 2, 3, 4] -> u'1, 2, 3, 4' + if value is not None: + if any(self.separator in six.text_type(item) for item in value): + raise ScalarListException( + "List values can't contain string '%s' (its being used as " + "separator. If you wish for scalar list values to contain " + "these strings, use a different separator string.)" + % self.separator + ) + return self.separator.join( + map(six.text_type, value) + ) + + def process_result_value(self, value, dialect): + if value is not None: + if value == u'': + return [] + # coerce each value + return list(map( + self.coerce_func, value.split(self.separator) + )) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/timezone.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/timezone.py new file mode 100644 index 0000000..730260f --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/timezone.py @@ -0,0 +1,88 @@ +import six +from sqlalchemy import types + +from sqlalchemy_utils.exceptions import ImproperlyConfigured + +from .scalar_coercible import ScalarCoercible + + +class TimezoneType(types.TypeDecorator, ScalarCoercible): + """ + TimezoneType provides a way for saving timezones (from either the pytz or + the dateutil package) objects into database. TimezoneType saves timezone + objects as strings on the way in and converts them back to objects when + querying the database. + + + :: + + from sqlalchemy_utils import TimezoneType + + class User(Base): + __tablename__ = 'user' + + # Pass backend='pytz' to change it to use pytz (dateutil by + # default) + timezone = sa.Column(TimezoneType(backend='pytz')) + """ + + impl = types.Unicode(50) + + python_type = None + + def __init__(self, backend='dateutil'): + """ + :param backend: Whether to use 'dateutil' or 'pytz' for timezones. + """ + + self.backend = backend + if backend == 'dateutil': + try: + from dateutil.tz import tzfile + from dateutil.zoneinfo import gettz + + self.python_type = tzfile + self._to = gettz + self._from = lambda x: six.text_type(x._filename) + + except ImportError: + raise ImproperlyConfigured( + "'python-dateutil' is required to use the " + "'dateutil' backend for 'TimezoneType'" + ) + + elif backend == 'pytz': + try: + from pytz import tzfile, timezone + + self.python_type = tzfile.DstTzInfo + self._to = timezone + self._from = six.text_type + + except ImportError: + raise ImproperlyConfigured( + "'pytz' is required to use the 'pytz' backend " + "for 'TimezoneType'" + ) + + else: + raise ImproperlyConfigured( + "'pytz' or 'dateutil' are the backends supported for " + "'TimezoneType'" + ) + + def _coerce(self, value): + if value and not isinstance(value, self.python_type): + obj = self._to(value) + if obj is None: + raise ValueError("unknown time zone '%s'" % value) + + return obj + + return value + + def process_bind_param(self, value, dialect): + return self._from(self._coerce(value)) if value else None + + def process_result_value(self, value, dialect): + return self._to(value) if value else None diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/ts_vector.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/ts_vector.py new file mode 100644 index 0000000..011bfc9 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/ts_vector.py @@ -0,0 +1,107 @@ +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import TSVECTOR + + +class TSVectorType(sa.types.TypeDecorator): + """ + .. note:: + + This type is PostgreSQL specific and is not supported by other + dialects. + + Provides additional functionality for SQLAlchemy PostgreSQL dialect's + TSVECTOR_ type. This additional functionality includes: + + * Vector concatenation + * regconfig constructor parameter which is applied to match function if no + postgresql_regconfig parameter is given + * Provides extensible base for extensions such as SQLAlchemy-Searchable_ + + .. _TSVECTOR: + http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#full-text-search + + .. _SQLAlchemy-Searchable: + https://www.github.com/kvesteri/sqlalchemy-searchable + + :: + + from sqlalchemy_utils import TSVectorType + + + class Article(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(100)) + search_vector = sa.Column(TSVectorType) + + + # Find all articles whose name matches 'finland' + session.query(Article).filter(Article.search_vector.match('finland')) + + + TSVectorType also supports vector concatenation. + + :: + + + class Article(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(100)) + name_vector = sa.Column(TSVectorType) + content = sa.Column(sa.String) + content_vector = sa.Column(TSVectorType) + + # Find all articles whose name or content matches 'finland' + session.query(Article).filter( + (Article.name_vector | Article.content_vector).match('finland') + ) + + You can configure TSVectorType to use a specific regconfig. + :: + + class Article(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(100)) + search_vector = sa.Column( + TSVectorType(regconfig='pg_catalog.simple') + ) + + + Now expression such as:: + + + Article.search_vector.match('finland') + + + Would be equivalent to SQL:: + + + search_vector @@ to_tsquery('pg_catalog.simgle', 'finland') + + """ + impl = TSVECTOR + + class comparator_factory(TSVECTOR.Comparator): + def match(self, other, **kwargs): + if 'postgresql_regconfig' not in kwargs: + if 'regconfig' in self.type.options: + kwargs['postgresql_regconfig'] = ( + self.type.options['regconfig'] + ) + return TSVECTOR.Comparator.match(self, other, **kwargs) + + def __or__(self, other): + return self.op('||')(other) + + def __init__(self, *args, **kwargs): + """ + Initializes new TSVectorType + + :param *args: list of column names + :param **kwargs: various other options for this TSVectorType + """ + self.columns = args + self.options = kwargs + super(TSVectorType, self).__init__() diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/url.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/url.py new file mode 100644 index 0000000..7dd4f84 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/url.py @@ -0,0 +1,67 @@ +furl = None +try: + from furl import furl +except ImportError: + pass +import six +from sqlalchemy import types + +from .scalar_coercible import ScalarCoercible + + +class URLType(types.TypeDecorator, ScalarCoercible): + """ + URLType stores furl_ objects into database. + + .. _furl: https://github.com/gruns/furl + + :: + + from sqlalchemy_utils import URLType + from furl import furl + + + class User(Base): + __tablename__ = 'user' + + id = sa.Column(sa.Integer, primary_key=True) + website = sa.Column(URLType) + + + user = User(website=u'www.example.com') + + # website is coerced to furl object, hence all nice furl operations + # come available + user.website.args['some_argument'] = '12' + + print user.website + # www.example.com?some_argument=12 + """ + + impl = types.UnicodeText + + def process_bind_param(self, value, dialect): + if furl is not None and isinstance(value, furl): + return six.text_type(value) + + if isinstance(value, six.string_types): + return value + + def process_result_value(self, value, dialect): + if furl is None: + return value + + if value is not None: + return furl(value) + + def _coerce(self, value): + if furl is None: + return value + + if value is not None and not isinstance(value, furl): + return furl(value) + return value + + @property + def python_type(self): + return self.impl.type.python_type diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/uuid.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/uuid.py new file mode 100644 index 0000000..2fc3ad1 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/uuid.py @@ -0,0 +1,78 @@ +from __future__ import absolute_import + +import uuid + +from sqlalchemy import types +from sqlalchemy.dialects import postgresql + +from .scalar_coercible import ScalarCoercible + + +class UUIDType(types.TypeDecorator, ScalarCoercible): + """ + Stores a UUID in the database natively when it can and falls back to + a BINARY(16) or a CHAR(32) when it can't. + + :: + + from sqlalchemy_utils import UUIDType + import uuid + + class User(Base): + __tablename__ = 'user' + + # Pass `binary=False` to fallback to CHAR instead of BINARY + id = sa.Column(UUIDType(binary=False), primary_key=True) + """ + impl = types.BINARY(16) + + python_type = uuid.UUID + + def __init__(self, binary=True, native=True): + """ + :param binary: Whether to use a BINARY(16) or CHAR(32) fallback. + """ + self.binary = binary + self.native = native + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql' and self.native: + # Use the native UUID type. + return dialect.type_descriptor(postgresql.UUID()) + + else: + # Fallback to either a BINARY or a CHAR. + kind = self.impl if self.binary else types.CHAR(32) + return dialect.type_descriptor(kind) + + @staticmethod + def _coerce(value): + if value and not isinstance(value, uuid.UUID): + try: + value = uuid.UUID(value) + + except (TypeError, ValueError): + value = uuid.UUID(bytes=value) + + return value + + def process_bind_param(self, value, dialect): + if value is None: + return value + + if not isinstance(value, uuid.UUID): + value = self._coerce(value) + + if self.native and dialect.name == 'postgresql': + return str(value) + + return value.bytes if self.binary else value.hex + + def process_result_value(self, value, dialect): + if value is None: + return value + + if self.native and dialect.name == 'postgresql': + return uuid.UUID(value) + + return uuid.UUID(bytes=value) if self.binary else uuid.UUID(value) diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/types/weekdays.py b/python-sqlalchemy-utils/sqlalchemy_utils/types/weekdays.py new file mode 100644 index 0000000..e3c3e95 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/types/weekdays.py @@ -0,0 +1,81 @@ +import six +from sqlalchemy import types + +from sqlalchemy_utils import i18n +from sqlalchemy_utils.exceptions import ImproperlyConfigured +from sqlalchemy_utils.primitives import WeekDay, WeekDays + +from .bit import BitType +from .scalar_coercible import ScalarCoercible + + +class WeekDaysType(types.TypeDecorator, ScalarCoercible): + """ + WeekDaysType offers way of saving WeekDays objects into database. The + WeekDays objects are converted to bit strings on the way in and back to + WeekDays objects on the way out. + + In order to use WeekDaysType you need to install Babel_ first. + + .. _Babel: http://babel.pocoo.org/ + + :: + + + from sqlalchemy_utils import WeekDaysType, WeekDays + from babel import Locale + + + class Schedule(Base): + __tablename__ = 'schedule' + id = sa.Column(sa.Integer, autoincrement=True) + working_days = sa.Column(WeekDaysType) + + + schedule = Schedule() + schedule.working_days = WeekDays('0001111') + session.add(schedule) + session.commit() + + print schedule.working_days # Thursday, Friday, Saturday, Sunday + + + WeekDaysType also supports scalar coercion: + + :: + + + schedule.working_days = '1110000' + schedule.working_days # WeekDays object + + """ + + impl = BitType(WeekDay.NUM_WEEK_DAYS) + + def __init__(self, *args, **kwargs): + if i18n.babel is None: + raise ImproperlyConfigured( + "'babel' package is required to use 'WeekDaysType'" + ) + + super(WeekDaysType, self).__init__(*args, **kwargs) + + @property + def comparator_factory(self): + return self.impl.comparator_factory + + def process_bind_param(self, value, dialect): + if isinstance(value, WeekDays): + return value.as_bit_string() + + if isinstance(value, six.string_types): + return value + + def process_result_value(self, value, dialect): + if value is not None: + return WeekDays(value) + + def _coerce(self, value): + if value is not None and not isinstance(value, WeekDays): + return WeekDays(value) + return value diff --git a/python-sqlalchemy-utils/sqlalchemy_utils/utils.py b/python-sqlalchemy-utils/sqlalchemy_utils/utils.py new file mode 100644 index 0000000..8c5b6b4 --- /dev/null +++ b/python-sqlalchemy-utils/sqlalchemy_utils/utils.py @@ -0,0 +1,22 @@ +import sys +from collections import Iterable + +import six + + +def str_coercible(cls): + if sys.version_info[0] >= 3: # Python 3 + def __str__(self): + return self.__unicode__() + else: # Python 2 + def __str__(self): + return self.__unicode__().encode('utf8') + + cls.__str__ = __str__ + return cls + + +def is_sequence(value): + return ( + isinstance(value, Iterable) and not isinstance(value, six.string_types) + ) diff --git a/python-sqlalchemy-utils/tests/__init__.py b/python-sqlalchemy-utils/tests/__init__.py new file mode 100644 index 0000000..7b32bd5 --- /dev/null +++ b/python-sqlalchemy-utils/tests/__init__.py @@ -0,0 +1,124 @@ +import warnings + +import sqlalchemy as sa +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base, synonym_for +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import sessionmaker + +from sqlalchemy_utils import ( + aggregates, + coercion_listener, + i18n, + InstrumentedList +) +from sqlalchemy_utils.types.pg_composite import remove_composite_listeners + + +@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute') +def count_sql_calls(conn, cursor, statement, parameters, context, executemany): + try: + conn.query_count += 1 + except AttributeError: + conn.query_count = 0 + + +warnings.simplefilter('error', sa.exc.SAWarning) + + +sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener) + + +def get_locale(): + class Locale(): + territories = {'FI': 'Finland'} + + return Locale() + + +class TestCase(object): + dns = 'sqlite:///:memory:' + create_tables = True + + def setup_method(self, method): + self.engine = create_engine(self.dns) + # self.engine.echo = True + self.connection = self.engine.connect() + self.Base = declarative_base() + + self.create_models() + sa.orm.configure_mappers() + if self.create_tables: + self.Base.metadata.create_all(self.connection) + + Session = sessionmaker(bind=self.connection) + self.session = Session() + + i18n.get_locale = get_locale + + def teardown_method(self, method): + aggregates.manager.reset() + self.session.close_all() + if self.create_tables: + self.Base.metadata.drop_all(self.connection) + remove_composite_listeners() + self.connection.close() + self.engine.dispose() + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @hybrid_property + def articles_count(self): + return len(self.articles) + + @articles_count.expression + def articles_count(cls): + return ( + sa.select([sa.func.count(self.Article.id)]) + .where(self.Article.category_id == self.Category.id) + .correlate(self.Article.__table__) + .label('article_count') + ) + + @property + def name_alias(self): + return self.name + + @synonym_for('name') + @property + def name_synonym(self): + return self.name + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255), index=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + category = sa.orm.relationship( + Category, + primaryjoin=category_id == Category.id, + backref=sa.orm.backref( + 'articles', + collection_class=InstrumentedList + ) + ) + + self.User = User + self.Category = Category + self.Article = Article + + +def assert_contains(clause, query): + # Test that query executes + query.all() + assert clause in str(query) diff --git a/python-sqlalchemy-utils/tests/aggregate/__init__.py b/python-sqlalchemy-utils/tests/aggregate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python-sqlalchemy-utils/tests/aggregate/test_backrefs.py b/python-sqlalchemy-utils/tests/aggregate/test_backrefs.py new file mode 100644 index 0000000..e7d0d2d --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_backrefs.py @@ -0,0 +1,61 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregateValueGenerationWithBackrefs(TestCase): + def create_models(self): + class Thread(self.Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('comments', sa.Column(sa.Integer, default=0)) + def comment_count(self): + return sa.func.count('1') + + class Comment(self.Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + + thread = sa.orm.relationship(Thread, backref='comments') + + self.Thread = Thread + self.Comment = Comment + + def test_assigns_aggregates_on_insert(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + + def test_assigns_aggregates_on_separate_insert(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + self.session.commit() + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + + def test_assigns_aggregates_on_delete(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + self.session.commit() + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.delete(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 0 diff --git a/python-sqlalchemy-utils/tests/aggregate/test_custom_select_expressions.py b/python-sqlalchemy-utils/tests/aggregate/test_custom_select_expressions.py new file mode 100644 index 0000000..de03b79 --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_custom_select_expressions.py @@ -0,0 +1,67 @@ +from decimal import Decimal + +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('products', sa.Column(sa.Numeric, default=0)) + def net_worth(self): + return sa.func.sum(Product.price) + + products = sa.orm.relationship('Product', backref='catalog') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + self.Catalog = Catalog + self.Product = Product + + def test_assigns_aggregates_on_insert(self): + catalog = self.Catalog( + name=u'Some catalog' + ) + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Some product', + price=Decimal('1000'), + catalog=catalog + ) + self.session.add(product) + self.session.commit() + self.session.refresh(catalog) + assert catalog.net_worth == Decimal('1000') + + def test_assigns_aggregates_on_update(self): + catalog = self.Catalog( + name=u'Some catalog' + ) + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Some product', + price=Decimal('1000'), + catalog=catalog + ) + self.session.add(product) + self.session.commit() + product.price = Decimal('500') + self.session.commit() + self.session.refresh(catalog) + assert catalog.net_worth == Decimal('500') diff --git a/python-sqlalchemy-utils/tests/aggregate/test_join_table_inheritance.py b/python-sqlalchemy-utils/tests/aggregate/test_join_table_inheritance.py new file mode 100644 index 0000000..fa891ec --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_join_table_inheritance.py @@ -0,0 +1,101 @@ +from decimal import Decimal + +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestLazyEvaluatedSelectExpressionsForAggregates(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type + } + + @aggregated('products', sa.Column(sa.Numeric, default=0)) + def net_worth(self): + return sa.func.sum(Product.price) + + products = sa.orm.relationship('Product', backref='catalog') + + class CostumeCatalog(Catalog): + __tablename__ = 'costume_catalog' + id = sa.Column( + sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True + ) + + __mapper_args__ = { + 'polymorphic_identity': 'costumes', + } + + class CarCatalog(Catalog): + __tablename__ = 'car_catalog' + id = sa.Column( + sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True + ) + + __mapper_args__ = { + 'polymorphic_identity': 'cars', + } + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + self.Catalog = Catalog + self.CostumeCatalog = CostumeCatalog + self.CarCatalog = CarCatalog + self.Product = Product + + def test_columns_inherited_from_parent(self): + assert self.CarCatalog.net_worth + assert self.CostumeCatalog.net_worth + assert self.Catalog.net_worth + assert not hasattr(self.CarCatalog.__table__.c, 'net_worth') + assert not hasattr(self.CostumeCatalog.__table__.c, 'net_worth') + + def test_assigns_aggregates_on_insert(self): + catalog = self.Catalog( + name=u'Some catalog' + ) + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Some product', + price=Decimal('1000'), + catalog=catalog + ) + self.session.add(product) + self.session.commit() + self.session.refresh(catalog) + assert catalog.net_worth == Decimal('1000') + + def test_assigns_aggregates_on_update(self): + catalog = self.Catalog( + name=u'Some catalog' + ) + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Some product', + price=Decimal('1000'), + catalog=catalog + ) + self.session.add(product) + self.session.commit() + product.price = Decimal('500') + self.session.commit() + self.session.refresh(catalog) + assert catalog.net_worth == Decimal('500') diff --git a/python-sqlalchemy-utils/tests/aggregate/test_m2m.py b/python-sqlalchemy-utils/tests/aggregate/test_m2m.py new file mode 100644 index 0000000..4a9ac45 --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_m2m.py @@ -0,0 +1,72 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregatesWithManyToManyRelationships(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + user_group = sa.Table( + 'user_group', + self.Base.metadata, + sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')), + sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) + ) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('groups', sa.Column(sa.Integer, default=0)) + def group_count(self): + return sa.func.count('1') + + groups = sa.orm.relationship( + 'Group', + backref='users', + secondary=user_group + ) + + class Group(self.Base): + __tablename__ = 'group' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + self.User = User + self.Group = Group + + def test_assigns_aggregates_on_insert(self): + user = self.User( + name=u'John Matrix' + ) + self.session.add(user) + self.session.commit() + group = self.Group( + name=u'Some group', + users=[user] + ) + self.session.add(group) + self.session.commit() + self.session.refresh(user) + assert user.group_count == 1 + + def test_updates_aggregates_on_delete(self): + user = self.User( + name=u'John Matrix' + ) + self.session.add(user) + self.session.commit() + group = self.Group( + name=u'Some group', + users=[user] + ) + self.session.add(group) + self.session.commit() + self.session.refresh(user) + user.groups = [] + self.session.commit() + self.session.refresh(user) + assert user.group_count == 0 diff --git a/python-sqlalchemy-utils/tests/aggregate/test_m2m_m2m.py b/python-sqlalchemy-utils/tests/aggregate/test_m2m_m2m.py new file mode 100644 index 0000000..1fee8c3 --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_m2m_m2m.py @@ -0,0 +1,80 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import aggregated +from tests import TestCase + + +class TestAggregateManyToManyAndManyToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + catalog_products = sa.Table( + 'catalog_product', + self.Base.metadata, + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), + sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + ) + + product_categories = sa.Table( + 'category_product', + self.Base.metadata, + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), + sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + ) + + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated( + 'products.categories', + sa.Column(sa.Integer, default=0) + ) + def category_count(self): + return sa.func.count(sa.distinct(Category.id)) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column( + sa.Integer, sa.ForeignKey('catalog.id') + ) + + catalogs = sa.orm.relationship( + Catalog, + backref='products', + secondary=catalog_products + ) + + categories = sa.orm.relationship( + Category, + backref='products', + secondary=product_categories + ) + + self.Catalog = Catalog + self.Category = Category + self.Product = Product + + def test_insert(self): + category = self.Category() + products = [ + self.Product(categories=[category]), + self.Product(categories=[category]) + ] + catalog = self.Catalog(products=products) + self.session.add(catalog) + catalog2 = self.Catalog(products=products) + self.session.add(catalog) + self.session.commit() + assert catalog.category_count == 1 + assert catalog2.category_count == 1 diff --git a/python-sqlalchemy-utils/tests/aggregate/test_multiple_aggregates_per_class.py b/python-sqlalchemy-utils/tests/aggregate/test_multiple_aggregates_per_class.py new file mode 100644 index 0000000..ee8b5dc --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_multiple_aggregates_per_class.py @@ -0,0 +1,81 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregateValueGenerationForSimpleModelPaths(TestCase): + def create_models(self): + class Thread(self.Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated( + 'comments', + sa.Column(sa.Integer, default=0) + ) + def comment_count(self): + return sa.func.count('1') + + @aggregated('comments', sa.Column(sa.Integer)) + def last_comment_id(self): + return sa.func.max(Comment.id) + + comments = sa.orm.relationship( + 'Comment', + backref='thread' + ) + + Thread.last_comment = sa.orm.relationship( + 'Comment', + primaryjoin='Thread.last_comment_id == Comment.id', + foreign_keys=[Thread.last_comment_id], + viewonly=True + ) + + class Comment(self.Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + + self.Thread = Thread + self.Comment = Comment + + def test_assigns_aggregates_on_insert(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + assert thread.last_comment_id == comment.id + + def test_assigns_aggregates_on_separate_insert(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + self.session.commit() + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + assert thread.last_comment_id == 1 + + def test_assigns_aggregates_on_delete(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + self.session.commit() + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.delete(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 0 + assert thread.last_comment_id is None diff --git a/python-sqlalchemy-utils/tests/aggregate/test_o2m_m2m.py b/python-sqlalchemy-utils/tests/aggregate/test_o2m_m2m.py new file mode 100644 index 0000000..c2b4ff5 --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_o2m_m2m.py @@ -0,0 +1,76 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import aggregated +from tests import TestCase + + +class TestAggregateOneToManyAndManyToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + product_categories = sa.Table( + 'category_product', + self.Base.metadata, + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), + sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) + ) + + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated( + 'products.categories', + sa.Column(sa.Integer, default=0) + ) + def category_count(self): + return sa.func.count(sa.distinct(Category.id)) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column( + sa.Integer, sa.ForeignKey('catalog.id') + ) + + catalog = sa.orm.relationship( + Catalog, + backref='products' + ) + + categories = sa.orm.relationship( + Category, + backref='products', + secondary=product_categories + ) + + self.Catalog = Catalog + self.Category = Category + self.Product = Product + + def test_insert(self): + category = self.Category() + products = [ + self.Product(categories=[category]), + self.Product(categories=[category]) + ] + catalog = self.Catalog(products=products) + self.session.add(catalog) + products2 = [ + self.Product(categories=[category]), + self.Product(categories=[category]) + ] + catalog2 = self.Catalog(products=products2) + self.session.add(catalog) + self.session.commit() + assert catalog.category_count == 1 + assert catalog2.category_count == 1 diff --git a/python-sqlalchemy-utils/tests/aggregate/test_o2m_o2m.py b/python-sqlalchemy-utils/tests/aggregate/test_o2m_o2m.py new file mode 100644 index 0000000..9abed54 --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_o2m_o2m.py @@ -0,0 +1,64 @@ +from decimal import Decimal + +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregateOneToManyAndOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated( + 'categories.products', + sa.Column(sa.Integer, default=0) + ) + def product_count(self): + return sa.func.count('1') + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + products = sa.orm.relationship('Product', backref='category') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + + self.Catalog = Catalog + self.Category = Category + self.Product = Product + + def test_assigns_aggregates(self): + category = self.Category(name=u'Some category') + catalog = self.Catalog( + categories=[category] + ) + catalog.name = u'Some catalog' + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Some product', + price=Decimal('1000'), + category=category + ) + self.session.add(product) + self.session.commit() + self.session.refresh(catalog) + assert catalog.product_count == 1 diff --git a/python-sqlalchemy-utils/tests/aggregate/test_o2m_o2m_o2m.py b/python-sqlalchemy-utils/tests/aggregate/test_o2m_o2m_o2m.py new file mode 100644 index 0000000..672045b --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_o2m_o2m_o2m.py @@ -0,0 +1,88 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import aggregated +from tests import TestCase + + +class Test3LevelDeepOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + + @aggregated( + 'categories.sub_categories.products', + sa.Column(sa.Integer, default=0) + ) + def product_count(self): + return sa.func.count('1') + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_categories = sa.orm.relationship( + 'SubCategory', backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship('Product', backref='sub_category') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def test_assigns_aggregates(self): + catalog = self.catalog_factory() + self.session.commit() + self.session.refresh(catalog) + assert catalog.product_count == 1 + + def catalog_factory(self): + product = self.Product() + sub_category = self.SubCategory( + products=[product] + ) + category = self.Category(sub_categories=[sub_category]) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + return catalog + + def test_only_updates_affected_aggregates(self): + catalog = self.catalog_factory() + catalog2 = self.catalog_factory() + self.session.commit() + + # force set catalog2 product_count to zero in order to check if it gets + # updated when the other catalog's product count gets updated + self.session.execute( + 'UPDATE catalog SET product_count = 0 WHERE id = %d' + % catalog2.id + ) + + catalog.categories[0].sub_categories[0].products.append( + self.Product() + ) + self.session.commit() + self.session.refresh(catalog) + self.session.refresh(catalog2) + assert catalog.product_count == 2 + assert catalog2.product_count == 0 diff --git a/python-sqlalchemy-utils/tests/aggregate/test_search_vectors.py b/python-sqlalchemy-utils/tests/aggregate/test_search_vectors.py new file mode 100644 index 0000000..d2a8621 --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_search_vectors.py @@ -0,0 +1,57 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import aggregated, TSVectorType +from tests import TestCase + + +def tsvector_reduce_concat(vectors): + return sa.sql.expression.cast( + sa.func.coalesce( + sa.func.array_to_string(sa.func.array_agg(vectors), ' ') + ), + TSVectorType + ) + + +class TestSearchVectorAggregates(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('products', sa.Column(TSVectorType)) + def product_search_vector(self): + return tsvector_reduce_concat( + sa.func.to_tsvector(Product.name) + ) + + products = sa.orm.relationship('Product', backref='catalog') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + price = sa.Column(sa.Numeric) + + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + self.Catalog = Catalog + self.Product = Product + + def test_assigns_aggregates_on_insert(self): + catalog = self.Catalog( + name=u'Some catalog' + ) + self.session.add(catalog) + self.session.commit() + product = self.Product( + name=u'Product XYZ', + catalog=catalog + ) + self.session.add(product) + self.session.commit() + self.session.refresh(catalog) + assert catalog.product_search_vector == "'product':1 'xyz':2" diff --git a/python-sqlalchemy-utils/tests/aggregate/test_simple_paths.py b/python-sqlalchemy-utils/tests/aggregate/test_simple_paths.py new file mode 100644 index 0000000..1f1c071 --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_simple_paths.py @@ -0,0 +1,61 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregateValueGenerationForSimpleModelPaths(TestCase): + def create_models(self): + class Thread(self.Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('comments', sa.Column(sa.Integer, default=0)) + def comment_count(self): + return sa.func.count('1') + + comments = sa.orm.relationship('Comment', backref='thread') + + class Comment(self.Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + + self.Thread = Thread + self.Comment = Comment + + def test_assigns_aggregates_on_insert(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + + def test_assigns_aggregates_on_separate_insert(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + self.session.commit() + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + + def test_assigns_aggregates_on_delete(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + self.session.commit() + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.delete(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 0 diff --git a/python-sqlalchemy-utils/tests/aggregate/test_with_column_alias.py b/python-sqlalchemy-utils/tests/aggregate/test_with_column_alias.py new file mode 100644 index 0000000..80a87e5 --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_with_column_alias.py @@ -0,0 +1,59 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregatedWithColumnAlias(TestCase): + def create_models(self): + class Thread(self.Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + + @aggregated( + 'comments', + sa.Column('_comment_count', sa.Integer, default=0) + ) + def comment_count(self): + return sa.func.count('1') + + comments = sa.orm.relationship('Comment', backref='thread') + + class Comment(self.Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) + + self.Thread = Thread + self.Comment = Comment + + def test_assigns_aggregates_on_insert(self): + thread = self.Thread() + self.session.add(thread) + comment = self.Comment(thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + + def test_assigns_aggregates_on_separate_insert(self): + thread = self.Thread() + self.session.add(thread) + self.session.commit() + comment = self.Comment(thread=thread) + self.session.add(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 1 + + def test_assigns_aggregates_on_delete(self): + thread = self.Thread() + self.session.add(thread) + self.session.commit() + comment = self.Comment(thread=thread) + self.session.add(comment) + self.session.commit() + self.session.delete(comment) + self.session.commit() + self.session.refresh(thread) + assert thread.comment_count == 0 diff --git a/python-sqlalchemy-utils/tests/aggregate/test_with_ondelete_cascade.py b/python-sqlalchemy-utils/tests/aggregate/test_with_ondelete_cascade.py new file mode 100644 index 0000000..93a26d0 --- /dev/null +++ b/python-sqlalchemy-utils/tests/aggregate/test_with_ondelete_cascade.py @@ -0,0 +1,47 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import aggregated +from tests import TestCase + + +class TestAggregateValueGenerationWithCascadeDelete(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Thread(self.Base): + __tablename__ = 'thread' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @aggregated('comments', sa.Column(sa.Integer, default=0)) + def comment_count(self): + return sa.func.count('1') + + comments = sa.orm.relationship( + 'Comment', + passive_deletes=True, + backref='thread' + ) + + class Comment(self.Base): + __tablename__ = 'comment' + id = sa.Column(sa.Integer, primary_key=True) + content = sa.Column(sa.Unicode(255)) + thread_id = sa.Column( + sa.Integer, + sa.ForeignKey('thread.id', ondelete='CASCADE') + ) + + self.Thread = Thread + self.Comment = Comment + + def test_something(self): + thread = self.Thread() + thread.name = u'some article name' + self.session.add(thread) + comment = self.Comment(content=u'Some content', thread=thread) + self.session.add(comment) + self.session.commit() + self.session.expire_all() + self.session.delete(thread) + self.session.commit() diff --git a/python-sqlalchemy-utils/tests/functions/__init__.py b/python-sqlalchemy-utils/tests/functions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python-sqlalchemy-utils/tests/functions/test_analyze.py b/python-sqlalchemy-utils/tests/functions/test_analyze.py new file mode 100644 index 0000000..1633efd --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_analyze.py @@ -0,0 +1,29 @@ +from sqlalchemy_utils import analyze +from tests import TestCase + + +class TestAnalyzeWithPostgres(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def test_runtime(self): + query = self.session.query(self.Article) + assert analyze(self.connection, query).runtime + + def test_node_types_with_join(self): + query = ( + self.session.query(self.Article) + .join(self.Article.category) + ) + analysis = analyze(self.connection, query) + assert analysis.node_types == [ + u'Hash Join', u'Seq Scan', u'Hash', u'Seq Scan' + ] + + def test_node_types_with_index_only_scan(self): + query = ( + self.session.query(self.Article.name) + .order_by(self.Article.name) + .limit(10) + ) + analysis = analyze(self.connection, query) + assert analysis.node_types == [u'Limit', u'Index Only Scan'] diff --git a/python-sqlalchemy-utils/tests/functions/test_database.py b/python-sqlalchemy-utils/tests/functions/test_database.py new file mode 100644 index 0000000..ab14b85 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_database.py @@ -0,0 +1,81 @@ +import os + +import sqlalchemy as sa +from flexmock import flexmock +from pytest import mark + +from sqlalchemy_utils import create_database, database_exists, drop_database +from tests import TestCase + +pymysql = None +try: + import pymysql # noqa +except ImportError: + pass + + +class DatabaseTest(TestCase): + def test_create_and_drop(self): + assert not database_exists(self.url) + create_database(self.url) + assert database_exists(self.url) + drop_database(self.url) + assert not database_exists(self.url) + + +class TestDatabaseSQLite(DatabaseTest): + url = 'sqlite:///sqlalchemy_utils.db' + + def setup(self): + if os.path.exists('sqlalchemy_utils.db'): + os.remove('sqlalchemy_utils.db') + + def test_exists_memory(self): + assert database_exists('sqlite:///:memory:') + + +@mark.skipif('pymysql is None') +class TestDatabaseMySQL(DatabaseTest): + url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy_util' + + +@mark.skipif('pymysql is None') +class TestDatabaseMySQLWithQuotedName(DatabaseTest): + url = 'mysql+pymysql://travis@localhost/db_test_sqlalchemy-util' + + +class TestDatabasePostgresWithQuotedName(DatabaseTest): + url = 'postgres://postgres@localhost/db_test_sqlalchemy-util' + + def test_template(self): + ( + flexmock(sa.engine.Engine) + .should_receive('execute') + .with_args( + '''CREATE DATABASE "db_test_sqlalchemy-util"''' + " ENCODING 'utf8' " + 'TEMPLATE "my-template"' + ) + ) + create_database( + 'postgres://postgres@localhost/db_test_sqlalchemy-util', + template='my-template' + ) + + +class TestDatabasePostgres(DatabaseTest): + url = 'postgres://postgres@localhost/db_test_sqlalchemy_util' + + def test_template(self): + ( + flexmock(sa.engine.Engine) + .should_receive('execute') + .with_args( + "CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' " + "TEMPLATE my_template" + ) + ) + create_database( + 'postgres://postgres@localhost/db_test_sqlalchemy_util', + template='my_template' + ) diff --git a/python-sqlalchemy-utils/tests/functions/test_dependent_objects.py b/python-sqlalchemy-utils/tests/functions/test_dependent_objects.py new file mode 100644 index 0000000..0702dd7 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_dependent_objects.py @@ -0,0 +1,295 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import dependent_objects, get_referencing_foreign_keys +from tests import TestCase + + +class TestDependentObjects(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + first_name = sa.Column(sa.Unicode(255)) + last_name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + owner_id = sa.Column( + sa.Integer, sa.ForeignKey('user.id', ondelete='SET NULL') + ) + + author = sa.orm.relationship(User, foreign_keys=[author_id]) + owner = sa.orm.relationship(User, foreign_keys=[owner_id]) + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + owner_id = sa.Column( + sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE') + ) + + owner = sa.orm.relationship(User) + + self.User = User + self.Article = Article + self.BlogPost = BlogPost + + def test_returns_all_dependent_objects(self): + user = self.User(first_name=u'John') + articles = [ + self.Article(author=user), + self.Article(), + self.Article(owner=user), + self.Article(author=user, owner=user) + ] + self.session.add_all(articles) + self.session.commit() + + deps = list(dependent_objects(user)) + assert len(deps) == 3 + assert articles[0] in deps + assert articles[2] in deps + assert articles[3] in deps + + def test_with_foreign_keys_parameter(self): + user = self.User(first_name=u'John') + objects = [ + self.Article(author=user), + self.Article(), + self.Article(owner=user), + self.Article(author=user, owner=user), + self.BlogPost(owner=user) + ] + self.session.add_all(objects) + self.session.commit() + + deps = list( + dependent_objects( + user, + ( + fk for fk in get_referencing_foreign_keys(self.User) + if fk.ondelete == 'RESTRICT' or fk.ondelete is None + ) + ).limit(5) + ) + assert len(deps) == 2 + assert objects[0] in deps + assert objects[3] in deps + + +class TestDependentObjectsWithColumnAliases(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + first_name = sa.Column(sa.Unicode(255)) + last_name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column( + '_author_id', sa.Integer, sa.ForeignKey('user.id') + ) + owner_id = sa.Column( + '_owner_id', + sa.Integer, sa.ForeignKey('user.id', ondelete='SET NULL') + ) + + author = sa.orm.relationship(User, foreign_keys=[author_id]) + owner = sa.orm.relationship(User, foreign_keys=[owner_id]) + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + owner_id = sa.Column( + '_owner_id', + sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE') + ) + + owner = sa.orm.relationship(User) + + self.User = User + self.Article = Article + self.BlogPost = BlogPost + + def test_returns_all_dependent_objects(self): + user = self.User(first_name=u'John') + articles = [ + self.Article(author=user), + self.Article(), + self.Article(owner=user), + self.Article(author=user, owner=user) + ] + self.session.add_all(articles) + self.session.commit() + + deps = list(dependent_objects(user)) + assert len(deps) == 3 + assert articles[0] in deps + assert articles[2] in deps + assert articles[3] in deps + + def test_with_foreign_keys_parameter(self): + user = self.User(first_name=u'John') + objects = [ + self.Article(author=user), + self.Article(), + self.Article(owner=user), + self.Article(author=user, owner=user), + self.BlogPost(owner=user) + ] + self.session.add_all(objects) + self.session.commit() + + deps = list( + dependent_objects( + user, + ( + fk for fk in get_referencing_foreign_keys(self.User) + if fk.ondelete == 'RESTRICT' or fk.ondelete is None + ) + ).limit(5) + ) + assert len(deps) == 2 + assert objects[0] in deps + assert objects[3] in deps + + +class TestDependentObjectsWithManyReferences(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + first_name = sa.Column(sa.Unicode(255)) + last_name = sa.Column(sa.Unicode(255)) + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + author = sa.orm.relationship(User) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + author = sa.orm.relationship(User) + + self.User = User + self.Article = Article + self.BlogPost = BlogPost + + def test_with_many_dependencies(self): + user = self.User(first_name=u'John') + objects = [ + self.Article(author=user), + self.BlogPost(author=user) + ] + self.session.add_all(objects) + self.session.commit() + deps = list(dependent_objects(user)) + assert len(deps) == 2 + + +class TestDependentObjectsWithCompositeKeys(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + first_name = sa.Column(sa.Unicode(255), primary_key=True) + last_name = sa.Column(sa.Unicode(255), primary_key=True) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_first_name = sa.Column(sa.Unicode(255)) + author_last_name = sa.Column(sa.Unicode(255)) + __table_args__ = ( + sa.ForeignKeyConstraint( + [author_first_name, author_last_name], + [User.first_name, User.last_name] + ), + ) + + author = sa.orm.relationship(User) + + self.User = User + self.Article = Article + + def test_returns_all_dependent_objects(self): + user = self.User(first_name=u'John', last_name=u'Smith') + articles = [ + self.Article(author=user), + self.Article(), + self.Article(), + self.Article(author=user) + ] + self.session.add_all(articles) + self.session.commit() + + deps = list(dependent_objects(user)) + assert len(deps) == 2 + assert articles[0] in deps + assert articles[3] in deps + + +class TestDependentObjectsWithSingleTableInheritance(TestCase): + def create_models(self): + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class TextItem(self.Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + category_id = sa.Column( + sa.Integer, + sa.ForeignKey(Category.id) + ) + category = sa.orm.relationship( + Category, + backref=sa.orm.backref( + 'articles' + ) + ) + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type, + } + + class Article(TextItem): + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + + class BlogPost(TextItem): + __mapper_args__ = { + 'polymorphic_identity': u'blog_post' + } + + self.Category = Category + self.TextItem = TextItem + self.Article = Article + self.BlogPost = BlogPost + + def test_returns_all_dependent_objects(self): + category1 = self.Category(name=u'Category #1') + category2 = self.Category(name=u'Category #2') + articles = [ + self.Article(category=category1), + self.Article(category=category1), + self.Article(category=category2), + self.Article(category=category2), + ] + self.session.add_all(articles) + self.session.commit() + + deps = list(dependent_objects(category1)) + assert len(deps) == 2 + assert articles[0] in deps + assert articles[1] in deps diff --git a/python-sqlalchemy-utils/tests/functions/test_escape_like.py b/python-sqlalchemy-utils/tests/functions/test_escape_like.py new file mode 100644 index 0000000..d1f78ca --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_escape_like.py @@ -0,0 +1,7 @@ +from sqlalchemy_utils import escape_like +from tests import TestCase + + +class TestEscapeLike(TestCase): + def test_escapes_wildcards(self): + assert escape_like('_*%') == '*_***%' diff --git a/python-sqlalchemy-utils/tests/functions/test_get_bind.py b/python-sqlalchemy-utils/tests/functions/test_get_bind.py new file mode 100644 index 0000000..c10a5f0 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_bind.py @@ -0,0 +1,21 @@ +from pytest import raises + +from sqlalchemy_utils import get_bind +from tests import TestCase + + +class TestGetBind(TestCase): + def test_with_session(self): + assert get_bind(self.session) == self.connection + + def test_with_connection(self): + assert get_bind(self.connection) == self.connection + + def test_with_model_object(self): + article = self.Article() + self.session.add(article) + assert get_bind(article) == self.connection + + def test_with_unknown_type(self): + with raises(TypeError): + get_bind(None) diff --git a/python-sqlalchemy-utils/tests/functions/test_get_class_by_table.py b/python-sqlalchemy-utils/tests/functions/test_get_class_by_table.py new file mode 100644 index 0000000..99db410 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_class_by_table.py @@ -0,0 +1,99 @@ +import sqlalchemy as sa +from pytest import raises +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import get_class_by_table + + +class TestGetClassByTableWithJoinedTableInheritance(object): + def setup_method(self, method): + self.Base = declarative_base() + + class Entity(self.Base): + __tablename__ = 'entity' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + type = sa.Column(sa.String) + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'entity' + } + + class User(Entity): + __tablename__ = 'user' + id = sa.Column( + sa.Integer, + sa.ForeignKey(Entity.id, ondelete='CASCADE'), + primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': 'user' + } + + self.Entity = Entity + self.User = User + + def test_returns_class(self): + assert get_class_by_table(self.Base, self.User.__table__) == self.User + assert get_class_by_table( + self.Base, + self.Entity.__table__ + ) == self.Entity + + def test_table_with_no_associated_class(self): + table = sa.Table( + 'some_table', + self.Base.metadata, + sa.Column('id', sa.Integer) + ) + assert get_class_by_table(self.Base, table) is None + + +class TestGetClassByTableWithSingleTableInheritance(object): + def setup_method(self, method): + self.Base = declarative_base() + + class Entity(self.Base): + __tablename__ = 'entity' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + type = sa.Column(sa.String) + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'entity' + } + + class User(Entity): + __mapper_args__ = { + 'polymorphic_identity': 'user' + } + + self.Entity = Entity + self.User = User + + def test_multiple_classes_without_data_parameter(self): + with raises(ValueError): + assert get_class_by_table( + self.Base, + self.Entity.__table__ + ) + + def test_multiple_classes_with_data_parameter(self): + assert get_class_by_table( + self.Base, + self.Entity.__table__, + {'type': 'entity'} + ) == self.Entity + assert get_class_by_table( + self.Base, + self.Entity.__table__, + {'type': 'user'} + ) == self.User + + def test_multiple_classes_with_bogus_data(self): + with raises(ValueError): + assert get_class_by_table( + self.Base, + self.Entity.__table__, + {'type': 'unknown'} + ) diff --git a/python-sqlalchemy-utils/tests/functions/test_get_column_key.py b/python-sqlalchemy-utils/tests/functions/test_get_column_key.py new file mode 100644 index 0000000..486483c --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_column_key.py @@ -0,0 +1,44 @@ +from copy import copy + +import sqlalchemy as sa +from pytest import raises +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import get_column_key + + +class TestGetColumnKey(object): + def setup_method(self, method): + Base = declarative_base() + + class Building(Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column('_name', sa.Unicode(255)) + + class Movie(Base): + __tablename__ = 'movie' + id = sa.Column(sa.Integer, primary_key=True) + + self.Building = Building + self.Movie = Movie + + def test_supports_aliases(self): + assert ( + get_column_key(self.Building, self.Building.__table__.c.id) + == + 'id' + ) + assert ( + get_column_key(self.Building, self.Building.__table__.c._name) + == + 'name' + ) + + def test_supports_vague_matching_of_column_objects(self): + column = copy(self.Building.__table__.c._name) + assert get_column_key(self.Building, column) == 'name' + + def test_throws_value_error_for_unknown_column(self): + with raises(sa.orm.exc.UnmappedColumnError): + get_column_key(self.Building, self.Movie.__table__.c.id) diff --git a/python-sqlalchemy-utils/tests/functions/test_get_columns.py b/python-sqlalchemy-utils/tests/functions/test_get_columns.py new file mode 100644 index 0000000..1ed3e3a --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_columns.py @@ -0,0 +1,53 @@ +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import get_columns + + +class TestGetColumns(object): + def setup_method(self, method): + Base = declarative_base() + + class Building(Base): + __tablename__ = 'building' + id = sa.Column('_id', sa.Integer, primary_key=True) + name = sa.Column('_name', sa.Unicode(255)) + + self.Building = Building + + def test_table(self): + assert isinstance( + get_columns(self.Building.__table__), + sa.sql.base.ImmutableColumnCollection + ) + + def test_declarative_class(self): + assert isinstance( + get_columns(self.Building), + sa.util._collections.OrderedProperties + ) + + def test_declarative_object(self): + assert isinstance( + get_columns(self.Building()), + sa.util._collections.OrderedProperties + ) + + def test_mapper(self): + assert isinstance( + get_columns(self.Building.__mapper__), + sa.util._collections.OrderedProperties + ) + + def test_class_alias(self): + assert isinstance( + get_columns(sa.orm.aliased(self.Building)), + sa.util._collections.OrderedProperties + ) + + def test_table_alias(self): + alias = sa.orm.aliased(self.Building.__table__) + assert isinstance( + get_columns(alias), + sa.sql.base.ImmutableColumnCollection + ) diff --git a/python-sqlalchemy-utils/tests/functions/test_get_hybrid_properties.py b/python-sqlalchemy-utils/tests/functions/test_get_hybrid_properties.py new file mode 100644 index 0000000..a37a664 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_hybrid_properties.py @@ -0,0 +1,37 @@ +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.hybrid import hybrid_property + +from sqlalchemy_utils import get_hybrid_properties + + +class TestGetHybridProperties(object): + def setup_method(self, method): + Base = declarative_base() + + class Category(Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + @hybrid_property + def lowercase_name(self): + return self.name.lower() + + @lowercase_name.expression + def lowercase_name(cls): + return sa.func.lower(cls.name) + + self.Category = Category + + def test_declarative_model(self): + assert ( + list(get_hybrid_properties(self.Category).keys()) == + ['lowercase_name'] + ) + + def test_mapper(self): + assert ( + list(get_hybrid_properties(sa.inspect(self.Category)).keys()) == + ['lowercase_name'] + ) diff --git a/python-sqlalchemy-utils/tests/functions/test_get_mapper.py b/python-sqlalchemy-utils/tests/functions/test_get_mapper.py new file mode 100644 index 0000000..766b495 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_mapper.py @@ -0,0 +1,134 @@ +import sqlalchemy as sa +from pytest import raises +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import get_mapper +from tests import TestCase + + +class TestGetMapper(object): + def setup_method(self, method): + self.Base = declarative_base() + + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + self.Building = Building + + def test_table(self): + assert get_mapper(self.Building.__table__) == sa.inspect(self.Building) + + def test_declarative_class(self): + assert ( + get_mapper(self.Building) == + sa.inspect(self.Building) + ) + + def test_declarative_object(self): + assert ( + get_mapper(self.Building()) == + sa.inspect(self.Building) + ) + + def test_mapper(self): + assert ( + get_mapper(self.Building.__mapper__) == + sa.inspect(self.Building) + ) + + def test_class_alias(self): + assert ( + get_mapper(sa.orm.aliased(self.Building)) == + sa.inspect(self.Building) + ) + + def test_instrumented_attribute(self): + assert ( + get_mapper(self.Building.id) == sa.inspect(self.Building) + ) + + def test_table_alias(self): + alias = sa.orm.aliased(self.Building.__table__) + assert ( + get_mapper(alias) == + sa.inspect(self.Building) + ) + + def test_column(self): + assert ( + get_mapper(self.Building.__table__.c.id) == + sa.inspect(self.Building) + ) + + def test_column_of_an_alias(self): + assert ( + get_mapper(sa.orm.aliased(self.Building.__table__).c.id) == + sa.inspect(self.Building) + ) + + +class TestGetMapperWithQueryEntities(TestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + self.Building = Building + + def test_mapper_entity_with_mapper(self): + entity = self.session.query(self.Building.__mapper__)._entities[0] + assert ( + get_mapper(entity) == + sa.inspect(self.Building) + ) + + def test_mapper_entity_with_class(self): + entity = self.session.query(self.Building)._entities[0] + assert ( + get_mapper(entity) == + sa.inspect(self.Building) + ) + + def test_column_entity(self): + query = self.session.query(self.Building.id) + assert get_mapper(query._entities[0]) == sa.inspect(self.Building) + + +class TestGetMapperWithMultipleMappersFound(object): + def setup_method(self, method): + Base = declarative_base() + + class Building(Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class BigBuilding(Building): + pass + + self.Building = Building + self.BigBuilding = BigBuilding + + def test_table(self): + with raises(ValueError): + get_mapper(self.Building.__table__) + + def test_table_alias(self): + alias = sa.orm.aliased(self.Building.__table__) + with raises(ValueError): + get_mapper(alias) + + +class TestGetMapperForTableWithoutMapper(object): + def setup_method(self, method): + metadata = sa.MetaData() + self.building = sa.Table('building', metadata) + + def test_table(self): + with raises(ValueError): + get_mapper(self.building) + + def test_table_alias(self): + alias = sa.orm.aliased(self.building) + with raises(ValueError): + get_mapper(alias) diff --git a/python-sqlalchemy-utils/tests/functions/test_get_primary_keys.py b/python-sqlalchemy-utils/tests/functions/test_get_primary_keys.py new file mode 100644 index 0000000..73525be --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_primary_keys.py @@ -0,0 +1,48 @@ +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import get_primary_keys + +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + + +class TestGetPrimaryKeys(object): + def setup_method(self, method): + Base = declarative_base() + + class Building(Base): + __tablename__ = 'building' + id = sa.Column('_id', sa.Integer, primary_key=True) + name = sa.Column('_name', sa.Unicode(255)) + + self.Building = Building + + def test_table(self): + assert get_primary_keys(self.Building.__table__) == OrderedDict({ + '_id': self.Building.__table__.c._id + }) + + def test_declarative_class(self): + assert get_primary_keys(self.Building) == OrderedDict({ + 'id': self.Building.__table__.c._id + }) + + def test_declarative_object(self): + assert get_primary_keys(self.Building()) == OrderedDict({ + 'id': self.Building.__table__.c._id + }) + + def test_class_alias(self): + alias = sa.orm.aliased(self.Building) + assert get_primary_keys(alias) == OrderedDict({ + 'id': self.Building.__table__.c._id + }) + + def test_table_alias(self): + alias = sa.orm.aliased(self.Building.__table__) + assert get_primary_keys(alias) == OrderedDict({ + '_id': alias.c._id + }) diff --git a/python-sqlalchemy-utils/tests/functions/test_get_query_entities.py b/python-sqlalchemy-utils/tests/functions/test_get_query_entities.py new file mode 100644 index 0000000..f1ae7c9 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_query_entities.py @@ -0,0 +1,102 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import get_query_entities +from tests import TestCase + + +class TestGetQueryEntities(TestCase): + def create_models(self): + class TextItem(self.Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type, + } + + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + category = sa.Column(sa.Unicode(255)) + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + + class BlogPost(TextItem): + __tablename__ = 'blog_post' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': u'blog_post' + } + + self.TextItem = TextItem + self.Article = Article + self.BlogPost = BlogPost + + def test_mapper(self): + query = self.session.query(sa.inspect(self.TextItem)) + assert get_query_entities(query) == [self.TextItem] + + def test_entity(self): + query = self.session.query(self.TextItem) + assert get_query_entities(query) == [self.TextItem] + + def test_instrumented_attribute(self): + query = self.session.query(self.TextItem.id) + assert get_query_entities(query) == [self.TextItem] + + def test_column(self): + query = self.session.query(self.TextItem.__table__.c.id) + assert get_query_entities(query) == [self.TextItem.__table__] + + def test_aliased_selectable(self): + selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) + query = self.session.query(selectable) + assert get_query_entities(query) == [selectable] + + def test_joined_entity(self): + query = self.session.query(self.TextItem).join( + self.BlogPost, self.BlogPost.id == self.TextItem.id + ) + assert get_query_entities(query) == [ + self.TextItem, sa.inspect(self.BlogPost) + ] + + def test_joined_aliased_entity(self): + alias = sa.orm.aliased(self.BlogPost) + + query = self.session.query(self.TextItem).join( + alias, alias.id == self.TextItem.id + ) + assert get_query_entities(query) == [self.TextItem, alias] + + def test_column_entity_with_label(self): + query = self.session.query(self.Article.id.label('id')) + assert get_query_entities(query) == [self.Article] + + def test_with_subquery(self): + number_of_articles = ( + sa.select( + [sa.func.count(self.Article.id)], + ) + .select_from( + self.Article.__table__ + ) + ).label('number_of_articles') + + query = self.session.query(self.Article, number_of_articles) + assert get_query_entities(query) == [ + self.Article, + number_of_articles + ] + + def test_aliased_entity(self): + alias = sa.orm.aliased(self.Article) + query = self.session.query(alias) + assert get_query_entities(query) == [alias] diff --git a/python-sqlalchemy-utils/tests/functions/test_get_referencing_foreign_keys.py b/python-sqlalchemy-utils/tests/functions/test_get_referencing_foreign_keys.py new file mode 100644 index 0000000..dc8a5de --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_referencing_foreign_keys.py @@ -0,0 +1,86 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import get_referencing_foreign_keys +from tests import TestCase + + +class TestGetReferencingFksWithCompositeKeys(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + first_name = sa.Column(sa.Unicode(255), primary_key=True) + last_name = sa.Column(sa.Unicode(255), primary_key=True) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_first_name = sa.Column(sa.Unicode(255)) + author_last_name = sa.Column(sa.Unicode(255)) + __table_args__ = ( + sa.ForeignKeyConstraint( + [author_first_name, author_last_name], + [User.first_name, User.last_name] + ), + ) + + self.User = User + self.Article = Article + + def test_with_declarative_class(self): + fks = get_referencing_foreign_keys(self.User) + assert self.Article.__table__.foreign_keys == fks + + def test_with_table(self): + fks = get_referencing_foreign_keys(self.User.__table__) + assert self.Article.__table__.foreign_keys == fks + + +class TestGetReferencingFksWithInheritance(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(sa.Unicode) + first_name = sa.Column(sa.Unicode(255)) + last_name = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': 'type' + } + + class Admin(User): + __tablename__ = 'admin' + id = sa.Column( + sa.Integer, sa.ForeignKey(User.id), primary_key=True + ) + + class TextItem(self.Base): + __tablename__ = 'textitem' + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(sa.Unicode) + author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) + __mapper_args__ = { + 'polymorphic_on': 'type' + } + + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': 'article' + } + + self.Admin = Admin + self.User = User + self.Article = Article + self.TextItem = TextItem + + def test_with_declarative_class(self): + fks = get_referencing_foreign_keys(self.Admin) + assert self.TextItem.__table__.foreign_keys == fks + + def test_with_table(self): + fks = get_referencing_foreign_keys(self.Admin.__table__) + assert fks == set([]) diff --git a/python-sqlalchemy-utils/tests/functions/test_get_tables.py b/python-sqlalchemy-utils/tests/functions/test_get_tables.py new file mode 100644 index 0000000..682e782 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_tables.py @@ -0,0 +1,76 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import get_tables +from tests import TestCase + + +class TestGetTables(TestCase): + def create_models(self): + class TextItem(self.Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type, + 'with_polymorphic': '*' + } + + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + + self.TextItem = TextItem + self.Article = Article + + def test_child_class_using_join_table_inheritance(self): + assert get_tables(self.Article) == [ + self.TextItem.__table__, + self.Article.__table__ + ] + + def test_entity_using_with_polymorphic(self): + assert get_tables(self.TextItem) == [ + self.TextItem.__table__, + self.Article.__table__ + ] + + def test_instrumented_attribute(self): + assert get_tables(self.TextItem.name) == [ + self.TextItem.__table__, + ] + + def test_polymorphic_instrumented_attribute(self): + assert get_tables(self.Article.id) == [ + self.TextItem.__table__, + self.Article.__table__ + ] + + def test_column(self): + assert get_tables(self.Article.__table__.c.id) == [ + self.Article.__table__ + ] + + def test_mapper_entity_with_class(self): + query = self.session.query(self.Article) + assert get_tables(query._entities[0]) == [ + self.TextItem.__table__, self.Article.__table__ + ] + + def test_mapper_entity_with_mapper(self): + query = self.session.query(sa.inspect(self.Article)) + assert get_tables(query._entities[0]) == [ + self.TextItem.__table__, self.Article.__table__ + ] + + def test_column_entity(self): + query = self.session.query(self.Article.id) + assert get_tables(query._entities[0]) == [ + self.TextItem.__table__, self.Article.__table__ + ] diff --git a/python-sqlalchemy-utils/tests/functions/test_get_type.py b/python-sqlalchemy-utils/tests/functions/test_get_type.py new file mode 100644 index 0000000..8990be4 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_get_type.py @@ -0,0 +1,42 @@ +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import get_type + + +class TestGetType(object): + def setup_method(self, method): + Base = declarative_base() + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + + author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) + author = sa.orm.relationship(User) + + some_property = sa.orm.column_property( + sa.func.coalesce(id, 1) + ) + + self.Article = Article + self.User = User + + def test_instrumented_attribute(self): + assert isinstance(get_type(self.Article.id), sa.Integer) + + def test_column_property(self): + assert isinstance(get_type(self.Article.id.property), sa.Integer) + + def test_column(self): + assert isinstance(get_type(self.Article.__table__.c.id), sa.Integer) + + def test_calculated_column_property(self): + assert isinstance(get_type(self.Article.some_property), sa.Integer) + + def test_relationship_property(self): + assert get_type(self.Article.author) == self.User diff --git a/python-sqlalchemy-utils/tests/functions/test_getdotattr.py b/python-sqlalchemy-utils/tests/functions/test_getdotattr.py new file mode 100644 index 0000000..c768ecb --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_getdotattr.py @@ -0,0 +1,85 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.functions import getdotattr +from tests import TestCase + + +class TestGetDotAttr(TestCase): + def create_models(self): + class Document(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Section(self.Base): + __tablename__ = 'section' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + document_id = sa.Column( + sa.Integer, sa.ForeignKey(Document.id) + ) + + document = sa.orm.relationship(Document, backref='sections') + + class SubSection(self.Base): + __tablename__ = 'subsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + section_id = sa.Column( + sa.Integer, sa.ForeignKey(Section.id) + ) + + section = sa.orm.relationship(Section, backref='subsections') + + class SubSubSection(self.Base): + __tablename__ = 'subsubsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + subsection_id = sa.Column( + sa.Integer, sa.ForeignKey(SubSection.id) + ) + + subsection = sa.orm.relationship( + SubSection, backref='subsubsections' + ) + + self.Document = Document + self.Section = Section + self.SubSection = SubSection + self.SubSubSection = SubSubSection + + def test_simple_objects(self): + document = self.Document(name=u'some document') + section = self.Section(document=document) + subsection = self.SubSection(section=section) + + assert getdotattr( + subsection, + 'section.document.name' + ) == u'some document' + + def test_with_instrumented_lists(self): + document = self.Document(name=u'some document') + section = self.Section(document=document) + subsection = self.SubSection(section=section) + subsubsection = self.SubSubSection(subsection=subsection) + + assert getdotattr(document, 'sections') == [section] + assert getdotattr(document, 'sections.subsections') == [ + subsection + ] + assert getdotattr(document, 'sections.subsections.subsubsections') == [ + subsubsection + ] + + def test_class_paths(self): + assert getdotattr(self.Section, 'document') is self.Section.document + assert ( + getdotattr(self.SubSection, 'section.document') is + self.Section.document + ) + assert getdotattr(self.Section, 'document.name') is self.Document.name diff --git a/python-sqlalchemy-utils/tests/functions/test_has_changes.py b/python-sqlalchemy-utils/tests/functions/test_has_changes.py new file mode 100644 index 0000000..6fc2684 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_has_changes.py @@ -0,0 +1,47 @@ +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import has_changes + + +class HasChangesTestCase(object): + def setup_method(self, method): + Base = declarative_base() + + class Article(Base): + __tablename__ = 'article_translation' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String(100)) + + self.Article = Article + + +class TestHasChangesWithStringAttr(HasChangesTestCase): + def test_without_changed_attr(self): + article = self.Article() + assert not has_changes(article, 'title') + + def test_with_changed_attr(self): + article = self.Article(title='Some title') + assert has_changes(article, 'title') + + +class TestHasChangesWithMultipleAttrs(HasChangesTestCase): + def test_without_changed_attr(self): + article = self.Article() + assert not has_changes(article, ['title']) + + def test_with_changed_attr(self): + article = self.Article(title='Some title') + assert has_changes(article, ['title', 'id']) + + +class TestHasChangesWithExclude(HasChangesTestCase): + def test_without_changed_attr(self): + article = self.Article() + assert not has_changes(article, exclude=['id']) + + def test_with_changed_attr(self): + article = self.Article(title='Some title') + assert has_changes(article, exclude=['id']) + assert not has_changes(article, exclude=['title']) diff --git a/python-sqlalchemy-utils/tests/functions/test_has_index.py b/python-sqlalchemy-utils/tests/functions/test_has_index.py new file mode 100644 index 0000000..76906f1 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_has_index.py @@ -0,0 +1,49 @@ +import sqlalchemy as sa +from pytest import raises +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import has_index + + +class TestHasIndex(object): + def setup_method(self, method): + Base = declarative_base() + + class ArticleTranslation(Base): + __tablename__ = 'article_translation' + id = sa.Column(sa.Integer, primary_key=True) + locale = sa.Column(sa.String(10), primary_key=True) + title = sa.Column(sa.String(100)) + is_published = sa.Column(sa.Boolean, index=True) + is_deleted = sa.Column(sa.Boolean) + is_archived = sa.Column(sa.Boolean) + + __table_args__ = ( + sa.Index('my_index', is_deleted, is_archived), + ) + + self.table = ArticleTranslation.__table__ + + def test_column_that_belongs_to_an_alias(self): + alias = sa.orm.aliased(self.table) + with raises(TypeError): + assert has_index(alias.c.id) + + def test_compound_primary_key(self): + assert has_index(self.table.c.id) + assert not has_index(self.table.c.locale) + + def test_single_column_index(self): + assert has_index(self.table.c.is_published) + + def test_compound_column_index(self): + assert has_index(self.table.c.is_deleted) + assert not has_index(self.table.c.is_archived) + + def test_table_without_primary_key(self): + article = sa.Table( + 'article', + sa.MetaData(), + sa.Column('name', sa.String) + ) + assert not has_index(article.c.name) diff --git a/python-sqlalchemy-utils/tests/functions/test_has_unique_index.py b/python-sqlalchemy-utils/tests/functions/test_has_unique_index.py new file mode 100644 index 0000000..157dfc8 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_has_unique_index.py @@ -0,0 +1,52 @@ +import sqlalchemy as sa +from pytest import raises +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import has_unique_index + + +class TestHasUniqueIndex(object): + def setup_method(self, method): + Base = declarative_base() + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + + class ArticleTranslation(Base): + __tablename__ = 'article_translation' + id = sa.Column(sa.Integer, primary_key=True) + locale = sa.Column(sa.String(10), primary_key=True) + title = sa.Column(sa.String(100)) + is_published = sa.Column(sa.Boolean, index=True) + is_deleted = sa.Column(sa.Boolean, unique=True) + is_archived = sa.Column(sa.Boolean) + + __table_args__ = ( + sa.Index('my_index', is_archived, is_published, unique=True), + ) + + self.articles = Article.__table__ + self.article_translations = ArticleTranslation.__table__ + + def test_primary_key(self): + assert has_unique_index(self.articles.c.id) + + def test_column_of_aliased_table(self): + alias = sa.orm.aliased(self.articles) + with raises(TypeError): + assert has_unique_index(alias.c.id) + + def test_unique_index(self): + assert has_unique_index(self.article_translations.c.is_deleted) + + def test_compound_primary_key(self): + assert not has_unique_index(self.article_translations.c.id) + assert not has_unique_index(self.article_translations.c.locale) + + def test_single_column_index(self): + assert not has_unique_index(self.article_translations.c.is_published) + + def test_compound_column_unique_index(self): + assert not has_unique_index(self.article_translations.c.is_published) + assert not has_unique_index(self.article_translations.c.is_archived) diff --git a/python-sqlalchemy-utils/tests/functions/test_identity.py b/python-sqlalchemy-utils/tests/functions/test_identity.py new file mode 100644 index 0000000..d1d3d42 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_identity.py @@ -0,0 +1,39 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.functions import identity +from tests import TestCase + + +class IdentityTestCase(TestCase): + def test_for_transient_class_without_id(self): + assert identity(self.Building()) == (None, ) + + def test_for_transient_class_with_id(self): + building = self.Building(name=u'Some building') + self.session.add(building) + self.session.flush() + + assert identity(building) == (building.id, ) + + def test_identity_for_class(self): + assert identity(self.Building) == (self.Building.id, ) + + +class TestIdentity(IdentityTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + self.Building = Building + + +class TestIdentityWithColumnAlias(IdentityTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column('_id', sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + self.Building = Building diff --git a/python-sqlalchemy-utils/tests/functions/test_is_loaded.py b/python-sqlalchemy-utils/tests/functions/test_is_loaded.py new file mode 100644 index 0000000..68d7d24 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_is_loaded.py @@ -0,0 +1,24 @@ +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import is_loaded + + +class TestIsLoaded(object): + def setup_method(self, method): + Base = declarative_base() + + class Article(Base): + __tablename__ = 'article_translation' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.orm.deferred(sa.Column(sa.String(100))) + + self.Article = Article + + def test_loaded_property(self): + article = self.Article(id=1) + assert is_loaded(article, 'id') + + def test_unloaded_property(self): + article = self.Article(id=4) + assert not is_loaded(article, 'title') diff --git a/python-sqlalchemy-utils/tests/functions/test_json_sql.py b/python-sqlalchemy-utils/tests/functions/test_json_sql.py new file mode 100644 index 0000000..bb3766a --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_json_sql.py @@ -0,0 +1,33 @@ +import pytest +import sqlalchemy as sa + +from sqlalchemy_utils import json_sql +from tests import TestCase + + +class TestJSONSQL(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + @pytest.mark.parametrize( + ('value', 'result'), + ( + (1, 1), + (14.14, 14.14), + ({'a': 2, 'b': 'c'}, {'a': 2, 'b': 'c'}), + ( + {'a': {'b': 'c'}}, + {'a': {'b': 'c'}} + ), + ({}, {}), + ([1, 2], [1, 2]), + ([], []), + ( + [sa.select([sa.text('1')]).label('alias')], + [1] + ) + ) + ) + def test_compiled_scalars(self, value, result): + assert result == ( + self.connection.execute(sa.select([json_sql(value)])).fetchone()[0] + ) diff --git a/python-sqlalchemy-utils/tests/functions/test_make_order_by_deterministic.py b/python-sqlalchemy-utils/tests/functions/test_make_order_by_deterministic.py new file mode 100644 index 0000000..632b1ca --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_make_order_by_deterministic.py @@ -0,0 +1,90 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.functions.sort_query import make_order_by_deterministic +from tests import assert_contains, TestCase + + +class TestMakeOrderByDeterministic(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode) + email = sa.Column(sa.Unicode, unique=True) + + email_lower = sa.orm.column_property( + sa.func.lower(name) + ) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + author = sa.orm.relationship(User) + + User.article_count = sa.orm.column_property( + sa.select([sa.func.count()], from_obj=Article) + .where(Article.author_id == User.id) + .label('article_count') + ) + + self.User = User + self.Article = Article + + def test_column_property(self): + query = self.session.query(self.User).order_by(self.User.email_lower) + query = make_order_by_deterministic(query) + assert_contains('lower("user".name), "user".id ASC', query) + + def test_unique_column(self): + query = self.session.query(self.User).order_by(self.User.email) + query = make_order_by_deterministic(query) + + assert str(query).endswith('ORDER BY "user".email') + + def test_non_unique_column(self): + query = self.session.query(self.User).order_by(self.User.name) + query = make_order_by_deterministic(query) + assert_contains('ORDER BY "user".name, "user".id ASC', query) + + def test_descending_order_by(self): + query = self.session.query(self.User).order_by( + sa.desc(self.User.name) + ) + query = make_order_by_deterministic(query) + assert_contains('ORDER BY "user".name DESC, "user".id DESC', query) + + def test_ascending_order_by(self): + query = self.session.query(self.User).order_by( + sa.asc(self.User.name) + ) + query = make_order_by_deterministic(query) + assert_contains('ORDER BY "user".name ASC, "user".id ASC', query) + + def test_string_order_by(self): + query = self.session.query(self.User).order_by('name') + query = make_order_by_deterministic(query) + assert_contains('ORDER BY "user".name, "user".id ASC', query) + + def test_annotated_label(self): + query = self.session.query(self.User).order_by(self.User.article_count) + query = make_order_by_deterministic(query) + assert_contains('article_count, "user".id ASC', query) + + def test_annotated_label_with_descending_order(self): + query = self.session.query(self.User).order_by( + sa.desc(self.User.article_count) + ) + query = make_order_by_deterministic(query) + assert_contains('ORDER BY article_count DESC, "user".id DESC', query) + + def test_query_without_order_by(self): + query = self.session.query(self.User) + query = make_order_by_deterministic(query) + assert 'ORDER BY "user".id' in str(query) + + def test_alias(self): + alias = sa.orm.aliased(self.User.__table__) + query = self.session.query(alias).order_by(alias.c.name) + query = make_order_by_deterministic(query) + assert str(query).endswith('ORDER BY user_1.name, "user".id ASC') diff --git a/python-sqlalchemy-utils/tests/functions/test_merge_references.py b/python-sqlalchemy-utils/tests/functions/test_merge_references.py new file mode 100644 index 0000000..e97f7a2 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_merge_references.py @@ -0,0 +1,171 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import merge_references +from tests import TestCase + + +class TestMergeReferences(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + def __repr__(self): + return 'User(%r)' % self.name + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.Unicode(255)) + content = sa.Column(sa.UnicodeText) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + + author = sa.orm.relationship(User) + + self.User = User + self.BlogPost = BlogPost + + def test_updates_foreign_keys(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + post = self.BlogPost(title=u'Some title', author=john) + post2 = self.BlogPost(title=u'Other title', author=jack) + self.session.add(john) + self.session.add(jack) + self.session.add(post) + self.session.add(post2) + self.session.commit() + merge_references(john, jack) + self.session.commit() + assert post.author == jack + assert post2.author == jack + + def test_object_merging_whenever_possible(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + post = self.BlogPost(title=u'Some title', author=john) + post2 = self.BlogPost(title=u'Other title', author=jack) + self.session.add(john) + self.session.add(jack) + self.session.add(post) + self.session.add(post2) + self.session.commit() + # Load the author for post + assert post.author_id == john.id + merge_references(john, jack) + assert post.author_id == jack.id + assert post2.author_id == jack.id + + +class TestMergeReferencesWithManyToManyAssociations(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + def __repr__(self): + return 'User(%r)' % self.name + + team_member = sa.Table( + 'team_member', self.Base.metadata, + sa.Column( + 'user_id', sa.Integer, + sa.ForeignKey('user.id', ondelete='CASCADE'), + primary_key=True + ), + sa.Column( + 'team_id', sa.Integer, + sa.ForeignKey('team.id', ondelete='CASCADE'), + primary_key=True + ) + ) + + class Team(self.Base): + __tablename__ = 'team' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + members = sa.orm.relationship( + User, + secondary=team_member, + backref='teams' + ) + + self.User = User + self.Team = Team + + def test_supports_associations(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(john) + self.session.add(john) + self.session.add(jack) + self.session.commit() + merge_references(john, jack) + assert john not in team.members + assert jack in team.members + + +class TestMergeReferencesWithManyToManyAssociationObjects(TestCase): + def create_models(self): + class Team(self.Base): + __tablename__ = 'team' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class TeamMember(self.Base): + __tablename__ = 'team_member' + user_id = sa.Column( + sa.Integer, + sa.ForeignKey(User.id, ondelete='CASCADE'), + primary_key=True + ) + team_id = sa.Column( + sa.Integer, + sa.ForeignKey(Team.id, ondelete='CASCADE'), + primary_key=True + ) + role = sa.Column(sa.Unicode(255)) + team = sa.orm.relationship( + Team, + backref=sa.orm.backref( + 'members', + cascade='all, delete-orphan' + ), + primaryjoin=team_id == Team.id, + ) + user = sa.orm.relationship( + User, + backref=sa.orm.backref( + 'memberships', + cascade='all, delete-orphan' + ), + primaryjoin=user_id == User.id, + ) + + self.User = User + self.TeamMember = TeamMember + self.Team = Team + + def test_supports_associations(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(self.TeamMember(user=john)) + self.session.add(john) + self.session.add(jack) + self.session.add(team) + self.session.commit() + merge_references(john, jack) + self.session.commit() + users = [member.user for member in team.members] + assert john not in users + assert jack in users diff --git a/python-sqlalchemy-utils/tests/functions/test_naturally_equivalent.py b/python-sqlalchemy-utils/tests/functions/test_naturally_equivalent.py new file mode 100644 index 0000000..c443e4d --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_naturally_equivalent.py @@ -0,0 +1,14 @@ +from sqlalchemy_utils.functions import naturally_equivalent +from tests import TestCase + + +class TestNaturallyEquivalent(TestCase): + def test_returns_true_when_properties_match(self): + assert naturally_equivalent( + self.User(name=u'someone'), self.User(name=u'someone') + ) + + def test_skips_primary_keys(self): + assert naturally_equivalent( + self.User(id=1, name=u'someone'), self.User(id=2, name=u'someone') + ) diff --git a/python-sqlalchemy-utils/tests/functions/test_non_indexed_foreign_keys.py b/python-sqlalchemy-utils/tests/functions/test_non_indexed_foreign_keys.py new file mode 100644 index 0000000..3c1791e --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_non_indexed_foreign_keys.py @@ -0,0 +1,57 @@ +from itertools import chain + +import sqlalchemy as sa + +from sqlalchemy_utils.functions import non_indexed_foreign_keys +from tests import TestCase + + +class TestFindNonIndexedForeignKeys(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + author_id = sa.Column( + sa.Integer, sa.ForeignKey(User.id), index=True + ) + category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) + + category = sa.orm.relationship( + Category, + primaryjoin=category_id == Category.id, + backref=sa.orm.backref( + 'articles', + ) + ) + + self.User = User + self.Category = Category + self.Article = Article + + def test_finds_all_non_indexed_fks(self): + fks = non_indexed_foreign_keys(self.Base.metadata, self.engine) + assert ( + 'article' in + fks + ) + column_names = list(chain( + *( + names for names in ( + fk.columns.keys() + for fk in fks['article'] + ) + ) + )) + assert 'category_id' in column_names + assert 'author_id' not in column_names diff --git a/python-sqlalchemy-utils/tests/functions/test_quote.py b/python-sqlalchemy-utils/tests/functions/test_quote.py new file mode 100644 index 0000000..85b7a31 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_quote.py @@ -0,0 +1,18 @@ +from sqlalchemy.dialects import postgresql + +from sqlalchemy_utils.functions import quote +from tests import TestCase + + +class TestQuote(TestCase): + def test_quote_with_preserved_keyword(self): + assert quote(self.connection, 'order') == '"order"' + assert quote(self.session, 'order') == '"order"' + assert quote(self.engine, 'order') == '"order"' + assert quote(postgresql.dialect(), 'order') == '"order"' + + def test_quote_with_non_preserved_keyword(self): + assert quote(self.connection, 'some_order') == 'some_order' + assert quote(self.session, 'some_order') == 'some_order' + assert quote(self.engine, 'some_order') == 'some_order' + assert quote(postgresql.dialect(), 'some_order') == 'some_order' diff --git a/python-sqlalchemy-utils/tests/functions/test_render.py b/python-sqlalchemy-utils/tests/functions/test_render.py new file mode 100644 index 0000000..2927617 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_render.py @@ -0,0 +1,58 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.functions import ( + mock_engine, + render_expression, + render_statement +) +from tests import TestCase + + +class TestRender(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + self.User = User + + def test_render_orm_query(self): + query = self.session.query(self.User).filter_by(id=3) + text = render_statement(query) + + assert 'SELECT user.id, user.name' in text + assert 'FROM user' in text + assert 'WHERE user.id = 3' in text + + def test_render_statement(self): + statement = self.User.__table__.select().where(self.User.id == 3) + text = render_statement(statement, bind=self.session.bind) + + assert 'SELECT user.id, user.name' in text + assert 'FROM user' in text + assert 'WHERE user.id = 3' in text + + def test_render_statement_without_mapper(self): + statement = sa.select([sa.text('1')]) + text = render_statement(statement, bind=self.session.bind) + + assert 'SELECT 1' in text + + def test_render_ddl(self): + expression = 'self.User.__table__.create(engine)' + stream = render_expression(expression, self.engine) + + text = stream.getvalue() + + assert 'CREATE TABLE user' in text + assert 'PRIMARY KEY' in text + + def test_render_mock_ddl(self): + with mock_engine('self.engine') as stream: + self.User.__table__.create(self.engine) + + text = stream.getvalue() + + assert 'CREATE TABLE user' in text + assert 'PRIMARY KEY' in text diff --git a/python-sqlalchemy-utils/tests/functions/test_table_name.py b/python-sqlalchemy-utils/tests/functions/test_table_name.py new file mode 100644 index 0000000..6018110 --- /dev/null +++ b/python-sqlalchemy-utils/tests/functions/test_table_name.py @@ -0,0 +1,26 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import table_name +from tests import TestCase + + +class TestTableName(TestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + self.Building = Building + + def test_class(self): + assert table_name(self.Building) == 'building' + del self.Building.__tablename__ + assert table_name(self.Building) == 'building' + + def test_attribute(self): + assert table_name(self.Building.id) == 'building' + assert table_name(self.Building.name) == 'building' + + def test_target(self): + assert table_name(self.Building()) == 'building' diff --git a/python-sqlalchemy-utils/tests/generic_relationship/__init__.py b/python-sqlalchemy-utils/tests/generic_relationship/__init__.py new file mode 100644 index 0000000..9b2de76 --- /dev/null +++ b/python-sqlalchemy-utils/tests/generic_relationship/__init__.py @@ -0,0 +1,109 @@ +from __future__ import unicode_literals + +import six + +from tests import TestCase + + +class GenericRelationshipTestCase(TestCase): + def test_set_as_none(self): + event = self.Event() + event.object = None + assert event.object is None + + def test_set_manual_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event() + event.object_id = user.id + event.object_type = six.text_type(type(user).__name__) + + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == user + + def test_set_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event(object=user) + + assert event.object_id == user.id + assert event.object_type == type(user).__name__ + + self.session.add(event) + self.session.commit() + + assert event.object == user + + def test_compare_instance(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + assert event.object == user1 + assert event.object != user2 + + def test_compare_query(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter_by(object=user1).first() is not None + assert q.filter_by(object=user2).first() is None + assert q.filter(self.Event.object == user2).first() is None + + def test_compare_not_query(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter(self.Event.object != user2).first() is not None + + def test_compare_type(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event1 = self.Event(object=user1) + event2 = self.Event(object=user2) + + self.session.add_all([event1, event2]) + self.session.commit() + + statement = self.Event.object.is_type(self.User) + q = self.session.query(self.Event).filter(statement) + assert q.first() is not None diff --git a/python-sqlalchemy-utils/tests/generic_relationship/test_abstract_base_class.py b/python-sqlalchemy-utils/tests/generic_relationship/test_abstract_base_class.py new file mode 100644 index 0000000..709f0db --- /dev/null +++ b/python-sqlalchemy-utils/tests/generic_relationship/test_abstract_base_class.py @@ -0,0 +1,36 @@ +from __future__ import unicode_literals + +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declared_attr + +from sqlalchemy_utils import generic_relationship +from tests.generic_relationship import GenericRelationshipTestCase + + +class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class EventBase(self.Base): + __abstract__ = True + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + @declared_attr + def object(cls): + return generic_relationship('object_type', 'object_id') + + class Event(EventBase): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + self.Building = Building + self.User = User + self.Event = Event diff --git a/python-sqlalchemy-utils/tests/generic_relationship/test_column_aliases.py b/python-sqlalchemy-utils/tests/generic_relationship/test_column_aliases.py new file mode 100644 index 0000000..1d93c79 --- /dev/null +++ b/python-sqlalchemy-utils/tests/generic_relationship/test_column_aliases.py @@ -0,0 +1,30 @@ +from __future__ import unicode_literals + +import sqlalchemy as sa + +from sqlalchemy_utils import generic_relationship +from tests.generic_relationship import GenericRelationshipTestCase + + +class TestGenericRelationship(GenericRelationshipTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255), name="objectType") + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + + self.Building = Building + self.User = User + self.Event = Event diff --git a/python-sqlalchemy-utils/tests/generic_relationship/test_composite_keys.py b/python-sqlalchemy-utils/tests/generic_relationship/test_composite_keys.py new file mode 100644 index 0000000..d730d09 --- /dev/null +++ b/python-sqlalchemy-utils/tests/generic_relationship/test_composite_keys.py @@ -0,0 +1,66 @@ +from __future__ import unicode_literals + +import six +import sqlalchemy as sa + +from sqlalchemy_utils import generic_relationship +from tests.generic_relationship import GenericRelationshipTestCase + + +class TestGenericRelationship(GenericRelationshipTestCase): + index = 1 + + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + code = sa.Column(sa.Integer, primary_key=True) + + def __init__(obj_self): + self.index += 1 + obj_self.id = self.index + obj_self.code = self.index + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + code = sa.Column(sa.Integer, primary_key=True) + + def __init__(obj_self): + self.index += 1 + obj_self.id = self.index + obj_self.code = self.index + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + object_code = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship( + object_type, (object_id, object_code) + ) + + self.Building = Building + self.User = User + self.Event = Event + + def test_set_manual_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event() + event.object_id = user.id + event.object_type = six.text_type(type(user).__name__) + event.object_code = user.code + + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == user diff --git a/python-sqlalchemy-utils/tests/generic_relationship/test_hybrid_properties.py b/python-sqlalchemy-utils/tests/generic_relationship/test_hybrid_properties.py new file mode 100644 index 0000000..ce4575f --- /dev/null +++ b/python-sqlalchemy-utils/tests/generic_relationship/test_hybrid_properties.py @@ -0,0 +1,68 @@ +from __future__ import unicode_literals + +import six +import sqlalchemy as sa +from sqlalchemy.ext.hybrid import hybrid_property + +from sqlalchemy_utils import generic_relationship +from tests import TestCase + + +class TestGenericRelationship(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class UserHistory(self.Base): + __tablename__ = 'user_history' + id = sa.Column(sa.Integer, primary_key=True) + + transaction_id = sa.Column(sa.Integer, primary_key=True) + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + transaction_id = sa.Column(sa.Integer) + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship( + object_type, object_id + ) + + @hybrid_property + def object_version_type(self): + return self.object_type + 'History' + + @object_version_type.expression + def object_version_type(cls): + return sa.func.concat(cls.object_type, 'History') + + object_version = generic_relationship( + object_version_type, (object_id, transaction_id) + ) + + self.User = User + self.UserHistory = UserHistory + self.Event = Event + + def test_set_manual_and_get(self): + user = self.User(id=1) + history = self.UserHistory(id=1, transaction_id=1) + self.session.add(user) + self.session.add(history) + self.session.commit() + + event = self.Event(transaction_id=1) + event.object_id = user.id + event.object_type = six.text_type(type(user).__name__) + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == user + assert event.object_version == history diff --git a/python-sqlalchemy-utils/tests/generic_relationship/test_single_table_inheritance.py b/python-sqlalchemy-utils/tests/generic_relationship/test_single_table_inheritance.py new file mode 100644 index 0000000..1013717 --- /dev/null +++ b/python-sqlalchemy-utils/tests/generic_relationship/test_single_table_inheritance.py @@ -0,0 +1,164 @@ +from __future__ import unicode_literals + +import six +import sqlalchemy as sa + +from sqlalchemy_utils import generic_relationship +from tests import TestCase + + +class TestGenericRelationship(TestCase): + def create_models(self): + class Employee(self.Base): + __tablename__ = 'employee' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(50)) + type = sa.Column(sa.String(20)) + + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'employee' + } + + class Manager(Employee): + __mapper_args__ = { + 'polymorphic_identity': 'manager' + } + + class Engineer(Employee): + __mapper_args__ = { + 'polymorphic_identity': 'engineer' + } + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + + self.Employee = Employee + self.Manager = Manager + self.Engineer = Engineer + self.Event = Event + + def test_set_as_none(self): + event = self.Event() + event.object = None + assert event.object is None + + def test_set_manual_and_get(self): + manager = self.Manager() + + self.session.add(manager) + self.session.commit() + + event = self.Event() + event.object_id = manager.id + event.object_type = six.text_type(type(manager).__name__) + + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == manager + + def test_set_and_get(self): + manager = self.Manager() + + self.session.add(manager) + self.session.commit() + + event = self.Event(object=manager) + + assert event.object_id == manager.id + assert event.object_type == type(manager).__name__ + + self.session.add(event) + self.session.commit() + + assert event.object == manager + + def test_compare_instance(self): + manager1 = self.Manager() + manager2 = self.Manager() + + self.session.add_all([manager1, manager2]) + self.session.commit() + + event = self.Event(object=manager1) + + self.session.add(event) + self.session.commit() + + assert event.object == manager1 + assert event.object != manager2 + + def test_compare_query(self): + manager1 = self.Manager() + manager2 = self.Manager() + + self.session.add_all([manager1, manager2]) + self.session.commit() + + event = self.Event(object=manager1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter_by(object=manager1).first() is not None + assert q.filter_by(object=manager2).first() is None + assert q.filter(self.Event.object == manager2).first() is None + + def test_compare_not_query(self): + manager1 = self.Manager() + manager2 = self.Manager() + + self.session.add_all([manager1, manager2]) + self.session.commit() + + event = self.Event(object=manager1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter(self.Event.object != manager2).first() is not None + + def test_compare_type(self): + manager1 = self.Manager() + manager2 = self.Manager() + + self.session.add_all([manager1, manager2]) + self.session.commit() + + event1 = self.Event(object=manager1) + event2 = self.Event(object=manager2) + + self.session.add_all([event1, event2]) + self.session.commit() + + statement = self.Event.object.is_type(self.Manager) + q = self.session.query(self.Event).filter(statement) + assert q.first() is not None + + def test_compare_super_type(self): + manager1 = self.Manager() + manager2 = self.Manager() + + self.session.add_all([manager1, manager2]) + self.session.commit() + + event1 = self.Event(object=manager1) + event2 = self.Event(object=manager2) + + self.session.add_all([event1, event2]) + self.session.commit() + + statement = self.Event.object.is_type(self.Employee) + q = self.session.query(self.Event).filter(statement) + assert q.first() is not None diff --git a/python-sqlalchemy-utils/tests/mixins.py b/python-sqlalchemy-utils/tests/mixins.py new file mode 100644 index 0000000..1224024 --- /dev/null +++ b/python-sqlalchemy-utils/tests/mixins.py @@ -0,0 +1,191 @@ +import sqlalchemy as sa + + +class ThreeLevelDeepOneToOne(object): + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column('_id', sa.Integer, primary_key=True) + category = sa.orm.relationship( + 'Category', + uselist=False, + backref='catalog' + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column('_id', sa.Integer, primary_key=True) + catalog_id = sa.Column( + '_catalog_id', + sa.Integer, + sa.ForeignKey('catalog._id') + ) + + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column('_id', sa.Integer, primary_key=True) + category_id = sa.Column( + '_category_id', + sa.Integer, + sa.ForeignKey('category._id') + ) + product = sa.orm.relationship( + 'Product', + uselist=False, + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column('_id', sa.Integer, primary_key=True) + price = sa.Column(sa.Integer) + + sub_category_id = sa.Column( + '_sub_category_id', + sa.Integer, + sa.ForeignKey('sub_category._id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + +class ThreeLevelDeepOneToMany(object): + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column('_id', sa.Integer, primary_key=True) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column('_id', sa.Integer, primary_key=True) + catalog_id = sa.Column( + '_catalog_id', + sa.Integer, + sa.ForeignKey('catalog._id') + ) + + sub_categories = sa.orm.relationship( + 'SubCategory', backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column('_id', sa.Integer, primary_key=True) + category_id = sa.Column( + '_category_id', + sa.Integer, + sa.ForeignKey('category._id') + ) + products = sa.orm.relationship( + 'Product', + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column('_id', sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + '_sub_category_id', + sa.Integer, + sa.ForeignKey('sub_category._id') + ) + + def __repr__(self): + return '' % self.id + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + +class ThreeLevelDeepManyToMany(object): + def create_models(self): + catalog_category = sa.Table( + 'catalog_category', + self.Base.metadata, + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')), + sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id')) + ) + + category_subcategory = sa.Table( + 'category_subcategory', + self.Base.metadata, + sa.Column( + 'category_id', + sa.Integer, + sa.ForeignKey('category._id') + ), + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category._id') + ) + ) + + subcategory_product = sa.Table( + 'subcategory_product', + self.Base.metadata, + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category._id') + ), + sa.Column( + 'product_id', + sa.Integer, + sa.ForeignKey('product._id') + ) + ) + + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column('_id', sa.Integer, primary_key=True) + + categories = sa.orm.relationship( + 'Category', + backref='catalogs', + secondary=catalog_category + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column('_id', sa.Integer, primary_key=True) + + sub_categories = sa.orm.relationship( + 'SubCategory', + backref='categories', + secondary=category_subcategory + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column('_id', sa.Integer, primary_key=True) + products = sa.orm.relationship( + 'Product', + backref='sub_categories', + secondary=subcategory_product + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column('_id', sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product diff --git a/python-sqlalchemy-utils/tests/observes/__init__.py b/python-sqlalchemy-utils/tests/observes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python-sqlalchemy-utils/tests/observes/test_column_property.py b/python-sqlalchemy-utils/tests/observes/test_column_property.py new file mode 100644 index 0000000..8339f4b --- /dev/null +++ b/python-sqlalchemy-utils/tests/observes/test_column_property.py @@ -0,0 +1,26 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForColumn(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Integer) + + @observes('price') + def product_price_observer(self, price): + self.price = price * 2 + + self.Product = Product + + def test_simple_insert(self): + product = self.Product(price=100) + self.session.add(product) + self.session.flush() + assert product.price == 200 diff --git a/python-sqlalchemy-utils/tests/observes/test_m2m_m2m_m2m.py b/python-sqlalchemy-utils/tests/observes/test_m2m_m2m_m2m.py new file mode 100644 index 0000000..3b416f2 --- /dev/null +++ b/python-sqlalchemy-utils/tests/observes/test_m2m_m2m_m2m.py @@ -0,0 +1,137 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForManyToManyToManyToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + catalog_category = sa.Table( + 'catalog_category', + self.Base.metadata, + sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), + sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) + ) + + category_subcategory = sa.Table( + 'category_subcategory', + self.Base.metadata, + sa.Column( + 'category_id', + sa.Integer, + sa.ForeignKey('category.id') + ), + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ) + ) + + subcategory_product = sa.Table( + 'subcategory_product', + self.Base.metadata, + sa.Column( + 'subcategory_id', + sa.Integer, + sa.ForeignKey('sub_category.id') + ), + sa.Column( + 'product_id', + sa.Integer, + sa.ForeignKey('product.id') + ) + ) + + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_categories.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship( + 'Category', + backref='catalogs', + secondary=catalog_category + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + + sub_categories = sa.orm.relationship( + 'SubCategory', + backref='categories', + secondary=category_subcategory + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + products = sa.orm.relationship( + 'Product', + backref='sub_categories', + secondary=subcategory_product + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(products=[self.Product()]) + category = self.Category(sub_categories=[sub_category]) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + self.session.flush() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_count == 1 + + def test_add_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + assert catalog.product_count == 2 + + def test_remove_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + self.session.delete(product) + self.session.flush() + assert catalog.product_count == 1 + + def test_delete_intermediate_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.categories[0].sub_categories[0]) + self.session.commit() + assert catalog.product_count == 0 + + def test_gathered_objects_are_distinct(self): + catalog = self.Catalog() + category = self.Category(catalogs=[catalog]) + product = self.Product() + category.sub_categories.append( + self.SubCategory(products=[product]) + ) + self.session.add( + self.SubCategory(categories=[category], products=[product]) + ) + self.session.commit() + assert catalog.product_count == 1 diff --git a/python-sqlalchemy-utils/tests/observes/test_o2m_o2m_o2m.py b/python-sqlalchemy-utils/tests/observes/test_o2m_o2m_o2m.py new file mode 100644 index 0000000..a11b378 --- /dev/null +++ b/python-sqlalchemy-utils/tests/observes/test_o2m_o2m_o2m.py @@ -0,0 +1,107 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesFor3LevelDeepOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_categories.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_categories = sa.orm.relationship( + 'SubCategory', backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship( + 'Product', + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + def __repr__(self): + return '' % self.id + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(products=[self.Product()]) + category = self.Category(sub_categories=[sub_category]) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + self.session.commit() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_count == 1 + + def test_add_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + assert catalog.product_count == 2 + + def test_remove_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_categories[0].products.append(product) + self.session.flush() + self.session.delete(product) + self.session.commit() + assert catalog.product_count == 1 + self.session.delete( + catalog.categories[0].sub_categories[0].products[0] + ) + self.session.commit() + assert catalog.product_count == 0 + + def test_delete_intermediate_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.categories[0].sub_categories[0]) + self.session.commit() + assert catalog.product_count == 0 + + def test_gathered_objects_are_distinct(self): + catalog = self.Catalog() + category = self.Category(catalog=catalog) + product = self.Product() + category.sub_categories.append( + self.SubCategory(products=[product]) + ) + self.session.add( + self.SubCategory(category=category, products=[product]) + ) + self.session.commit() + assert catalog.product_count == 1 diff --git a/python-sqlalchemy-utils/tests/observes/test_o2m_o2o_o2m.py b/python-sqlalchemy-utils/tests/observes/test_o2m_o2o_o2m.py new file mode 100644 index 0000000..2299280 --- /dev/null +++ b/python-sqlalchemy-utils/tests/observes/test_o2m_o2o_o2m.py @@ -0,0 +1,96 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForOneToManyToOneToMany(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_count = sa.Column(sa.Integer, default=0) + + @observes('categories.sub_category.products') + def product_observer(self, products): + self.product_count = len(products) + + categories = sa.orm.relationship('Category', backref='catalog') + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + products = sa.orm.relationship('Product', backref='sub_category') + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Numeric) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(products=[self.Product()]) + category = self.Category(sub_category=sub_category) + catalog = self.Catalog(categories=[category]) + self.session.add(catalog) + self.session.flush() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_count == 1 + + def test_add_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_category.products.append(product) + self.session.flush() + assert catalog.product_count == 2 + + def test_remove_leaf_object(self): + catalog = self.create_catalog() + product = self.Product() + catalog.categories[0].sub_category.products.append(product) + self.session.flush() + self.session.delete(product) + self.session.flush() + assert catalog.product_count == 1 + + def test_delete_intermediate_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.categories[0].sub_category) + self.session.commit() + assert catalog.product_count == 0 + + def test_gathered_objects_are_distinct(self): + catalog = self.Catalog() + category = self.Category(catalog=catalog) + product = self.Product() + category.sub_category = self.SubCategory(products=[product]) + self.session.add( + self.Category(catalog=catalog, sub_category=category.sub_category) + ) + self.session.commit() + assert catalog.product_count == 1 diff --git a/python-sqlalchemy-utils/tests/observes/test_o2o_o2o_o2o.py b/python-sqlalchemy-utils/tests/observes/test_o2o_o2o_o2o.py new file mode 100644 index 0000000..00cfca8 --- /dev/null +++ b/python-sqlalchemy-utils/tests/observes/test_o2o_o2o_o2o.py @@ -0,0 +1,84 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.observer import observes +from tests import TestCase + + +class TestObservesForOneToOneToOneToOne(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Catalog(self.Base): + __tablename__ = 'catalog' + id = sa.Column(sa.Integer, primary_key=True) + product_price = sa.Column(sa.Integer) + + @observes('category.sub_category.product') + def product_observer(self, product): + self.product_price = product.price if product else None + + category = sa.orm.relationship( + 'Category', + uselist=False, + backref='catalog' + ) + + class Category(self.Base): + __tablename__ = 'category' + id = sa.Column(sa.Integer, primary_key=True) + catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) + + sub_category = sa.orm.relationship( + 'SubCategory', + uselist=False, + backref='category' + ) + + class SubCategory(self.Base): + __tablename__ = 'sub_category' + id = sa.Column(sa.Integer, primary_key=True) + category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) + product = sa.orm.relationship( + 'Product', + uselist=False, + backref='sub_category' + ) + + class Product(self.Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, primary_key=True) + price = sa.Column(sa.Integer) + + sub_category_id = sa.Column( + sa.Integer, sa.ForeignKey('sub_category.id') + ) + + self.Catalog = Catalog + self.Category = Category + self.SubCategory = SubCategory + self.Product = Product + + def create_catalog(self): + sub_category = self.SubCategory(product=self.Product(price=123)) + category = self.Category(sub_category=sub_category) + catalog = self.Catalog(category=category) + self.session.add(catalog) + self.session.flush() + return catalog + + def test_simple_insert(self): + catalog = self.create_catalog() + assert catalog.product_price == 123 + + def test_replace_leaf_object(self): + catalog = self.create_catalog() + product = self.Product(price=44) + catalog.category.sub_category.product = product + self.session.flush() + assert catalog.product_price == 44 + + def test_delete_leaf_object(self): + catalog = self.create_catalog() + self.session.delete(catalog.category.sub_category.product) + self.session.flush() + assert catalog.product_price is None diff --git a/python-sqlalchemy-utils/tests/primitives/__init__.py b/python-sqlalchemy-utils/tests/primitives/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python-sqlalchemy-utils/tests/primitives/test_country.py b/python-sqlalchemy-utils/tests/primitives/test_country.py new file mode 100644 index 0000000..751e876 --- /dev/null +++ b/python-sqlalchemy-utils/tests/primitives/test_country.py @@ -0,0 +1,67 @@ +import six +from pytest import mark, raises + +from sqlalchemy_utils import Country, i18n + + +@mark.skipif('i18n.babel is None') +class TestCountry(object): + def setup_method(self, method): + i18n.get_locale = lambda: i18n.babel.Locale('en') + + def test_init(self): + assert Country(u'FI') == Country(Country(u'FI')) + + def test_constructor_with_wrong_type(self): + with raises(TypeError) as e: + Country(None) + assert str(e.value) == ( + "Country() argument must be a string or a country, not 'NoneType'" + ) + + def test_constructor_with_invalid_code(self): + with raises(ValueError) as e: + Country('SomeUnknownCode') + assert str(e.value) == ( + 'Could not convert string to country code: SomeUnknownCode' + ) + + @mark.parametrize( + 'code', + ( + 'FI', + 'US', + ) + ) + def test_validate_with_valid_codes(self, code): + Country.validate(code) + + def test_validate_with_invalid_code(self): + with raises(ValueError) as e: + Country.validate('SomeUnknownCode') + assert str(e.value) == ( + 'Could not convert string to country code: SomeUnknownCode' + ) + + def test_equality_operator(self): + assert Country(u'FI') == u'FI' + assert u'FI' == Country(u'FI') + assert Country(u'FI') == Country(u'FI') + + def test_non_equality_operator(self): + assert Country(u'FI') != u'sv' + assert not (Country(u'FI') != u'FI') + + def test_hash(self): + return hash(Country('FI')) == hash('FI') + + def test_repr(self): + return repr(Country('FI')) == "Country('FI')" + + def test_unicode(self): + country = Country('FI') + assert six.text_type(country) == u'Finland' + + def test_str(self): + country = Country('FI') + assert str(country) == 'Finland' diff --git a/python-sqlalchemy-utils/tests/primitives/test_currency.py b/python-sqlalchemy-utils/tests/primitives/test_currency.py new file mode 100644 index 0000000..a6c4876 --- /dev/null +++ b/python-sqlalchemy-utils/tests/primitives/test_currency.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +import six +from pytest import mark, raises + +from sqlalchemy_utils import Currency, i18n + + +@mark.skipif('i18n.babel is None') +class TestCurrency(object): + def setup_method(self, method): + i18n.get_locale = lambda: i18n.babel.Locale('en') + + def test_init(self): + assert Currency('USD') == Currency(Currency('USD')) + + def test_hashability(self): + assert len(set([Currency('USD'), Currency('USD')])) == 1 + + def test_invalid_currency_code(self): + with raises(ValueError): + Currency('Unknown code') + + def test_invalid_currency_code_type(self): + with raises(TypeError): + Currency(None) + + @mark.parametrize( + ('code', 'name'), + ( + ('USD', 'US Dollar'), + ('EUR', 'Euro') + ) + ) + def test_name_property(self, code, name): + assert Currency(code).name == name + + @mark.parametrize( + ('code', 'symbol'), + ( + ('USD', u'$'), + ('EUR', u'€') + ) + ) + def test_symbol_property(self, code, symbol): + assert Currency(code).symbol == symbol + + def test_equality_operator(self): + assert Currency('USD') == 'USD' + assert 'USD' == Currency('USD') + assert Currency('USD') == Currency('USD') + + def test_non_equality_operator(self): + assert Currency('USD') != 'EUR' + assert not (Currency('USD') != 'USD') + + def test_unicode(self): + currency = Currency('USD') + assert six.text_type(currency) == u'USD' + + def test_str(self): + currency = Currency('USD') + assert str(currency) == 'USD' + + def test_representation(self): + currency = Currency('USD') + assert repr(currency) == "Currency('USD')" diff --git a/python-sqlalchemy-utils/tests/primitives/test_weekdays.py b/python-sqlalchemy-utils/tests/primitives/test_weekdays.py new file mode 100644 index 0000000..3e1b246 --- /dev/null +++ b/python-sqlalchemy-utils/tests/primitives/test_weekdays.py @@ -0,0 +1,165 @@ +import pytest +import six +from flexmock import flexmock + +from sqlalchemy_utils import i18n +from sqlalchemy_utils.primitives import WeekDay, WeekDays + + +@pytest.mark.skipif('i18n.babel is None') +class TestWeekDay(object): + def setup_method(self, method): + i18n.get_locale = lambda: i18n.babel.Locale('fi') + + def test_constructor_with_valid_index(self): + day = WeekDay(1) + assert day.index == 1 + + @pytest.mark.parametrize('index', [-1, 7]) + def test_constructor_with_invalid_index(self, index): + with pytest.raises(ValueError): + WeekDay(index) + + def test_equality_with_equal_week_day(self): + day = WeekDay(1) + day2 = WeekDay(1) + assert day == day2 + + def test_equality_with_unequal_week_day(self): + day = WeekDay(1) + day2 = WeekDay(2) + assert day != day2 + + def test_equality_with_unsupported_comparison(self): + day = WeekDay(1) + assert day != 'foobar' + + def test_hash_is_equal_to_index_hash(self): + day = WeekDay(1) + assert hash(day) == hash(day.index) + + def test_representation(self): + day = WeekDay(1) + assert repr(day) == "WeekDay(1)" + + @pytest.mark.parametrize( + ('index', 'first_week_day', 'position'), + [ + (0, 0, 0), + (1, 0, 1), + (6, 0, 6), + (0, 6, 1), + (1, 6, 2), + (6, 6, 0), + ] + ) + def test_position(self, index, first_week_day, position): + i18n.get_locale = flexmock(first_week_day=first_week_day) + day = WeekDay(index) + assert day.position == position + + def test_get_name_returns_localized_week_day_name(self): + day = WeekDay(0) + assert day.get_name() == u'maanantaina' + + def test_override_get_locale_as_class_method(self): + day = WeekDay(0) + assert day.get_name() == u'maanantaina' + + def test_name_delegates_to_get_name(self): + day = WeekDay(0) + flexmock(day).should_receive('get_name').and_return(u'maanantaina') + assert day.name == u'maanantaina' + + def test_unicode(self): + day = WeekDay(0) + flexmock(day).should_receive('name').and_return(u'maanantaina') + assert six.text_type(day) == u'maanantaina' + + def test_str(self): + day = WeekDay(0) + flexmock(day).should_receive('name').and_return(u'maanantaina') + assert str(day) == 'maanantaina' + + +@pytest.mark.skipif('i18n.babel is None') +class TestWeekDays(object): + def test_constructor_with_valid_bit_string(self): + days = WeekDays('1000100') + assert days._days == set([WeekDay(0), WeekDay(4)]) + + @pytest.mark.parametrize( + 'bit_string', + [ + '000000', # too short + '00000000', # too long + ] + ) + def test_constructor_with_bit_string_of_invalid_length(self, bit_string): + with pytest.raises(ValueError): + WeekDays(bit_string) + + def test_constructor_with_bit_string_containing_invalid_characters(self): + with pytest.raises(ValueError): + WeekDays('foobarz') + + def test_constructor_with_another_week_days_object(self): + days = WeekDays('0000000') + another_days = WeekDays(days) + assert days._days == another_days._days + + def test_representation(self): + days = WeekDays('0000000') + assert repr(days) == "WeekDays('0000000')" + + @pytest.mark.parametrize( + 'bit_string', + [ + '0000000', + '1000000', + '0000001', + '0101000', + '1111111', + ] + ) + def test_as_bit_string(self, bit_string): + days = WeekDays(bit_string) + assert days.as_bit_string() == bit_string + + def test_equality_with_equal_week_days_object(self): + days = WeekDays('0001000') + days2 = WeekDays('0001000') + assert days == days2 + + def test_equality_with_unequal_week_days_object(self): + days = WeekDays('0001000') + days2 = WeekDays('1000000') + assert days != days2 + + def test_equality_with_equal_bit_string(self): + days = WeekDays('0001000') + assert days == '0001000' + + def test_equality_with_unequal_bit_string(self): + days = WeekDays('0001000') + assert days != '0101000' + + def test_equality_with_unsupported_comparison(self): + days = WeekDays('0001000') + assert days != 0 + + def test_iterator_starts_from_locales_first_week_day(self): + i18n.get_locale = lambda: flexmock(first_week_day=1) + days = WeekDays('1111111') + indices = list(day.index for day in days) + assert indices == [1, 2, 3, 4, 5, 6, 0] + + def test_unicode(self): + i18n.get_locale = lambda: i18n.babel.Locale('fi') + days = WeekDays('1000100') + assert six.text_type(days) == u'maanantaina, perjantaina' + + def test_str(self): + i18n.get_locale = lambda: i18n.babel.Locale('fi') + days = WeekDays('1000100') + assert str(days) == 'maanantaina, perjantaina' diff --git a/python-sqlalchemy-utils/tests/relationships/__init__.py b/python-sqlalchemy-utils/tests/relationships/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python-sqlalchemy-utils/tests/relationships/test_chained_join.py b/python-sqlalchemy-utils/tests/relationships/test_chained_join.py new file mode 100644 index 0000000..aff42d4 --- /dev/null +++ b/python-sqlalchemy-utils/tests/relationships/test_chained_join.py @@ -0,0 +1,107 @@ +from sqlalchemy_utils.relationships import chained_join +from tests import TestCase +from tests.mixins import ( + ThreeLevelDeepManyToMany, + ThreeLevelDeepOneToMany, + ThreeLevelDeepOneToOne +) + + +class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany, TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + create_tables = False + + def test_simple_join(self): + assert str(chained_join(self.Catalog.categories)) == ( + 'catalog_category JOIN category ON ' + 'category._id = catalog_category.category_id' + ) + + def test_two_relations(self): + sql = chained_join( + self.Catalog.categories, + self.Category.sub_categories + ) + assert str(sql) == ( + 'catalog_category JOIN category ON category._id = ' + 'catalog_category.category_id JOIN category_subcategory ON ' + 'category._id = category_subcategory.category_id JOIN ' + 'sub_category ON sub_category._id = ' + 'category_subcategory.subcategory_id' + ) + + def test_three_relations(self): + sql = chained_join( + self.Catalog.categories, + self.Category.sub_categories, + self.SubCategory.products + ) + assert str(sql) == ( + 'catalog_category JOIN category ON category._id = ' + 'catalog_category.category_id JOIN category_subcategory ON ' + 'category._id = category_subcategory.category_id JOIN sub_category' + ' ON sub_category._id = category_subcategory.subcategory_id JOIN ' + 'subcategory_product ON sub_category._id = ' + 'subcategory_product.subcategory_id JOIN product ON product._id =' + ' subcategory_product.product_id' + ) + + +class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany, TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + create_tables = False + + def test_simple_join(self): + assert str(chained_join(self.Catalog.categories)) == 'category' + + def test_two_relations(self): + sql = chained_join( + self.Catalog.categories, + self.Category.sub_categories + ) + assert str(sql) == ( + 'category JOIN sub_category ON category._id = ' + 'sub_category._category_id' + ) + + def test_three_relations(self): + sql = chained_join( + self.Catalog.categories, + self.Category.sub_categories, + self.SubCategory.products + ) + assert str(sql) == ( + 'category JOIN sub_category ON category._id = ' + 'sub_category._category_id JOIN product ON sub_category._id = ' + 'product._sub_category_id' + ) + + +class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne, TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + create_tables = False + + def test_simple_join(self): + assert str(chained_join(self.Catalog.category)) == 'category' + + def test_two_relations(self): + sql = chained_join( + self.Catalog.category, + self.Category.sub_category + ) + assert str(sql) == ( + 'category JOIN sub_category ON category._id = ' + 'sub_category._category_id' + ) + + def test_three_relations(self): + sql = chained_join( + self.Catalog.category, + self.Category.sub_category, + self.SubCategory.product + ) + assert str(sql) == ( + 'category JOIN sub_category ON category._id = ' + 'sub_category._category_id JOIN product ON sub_category._id = ' + 'product._sub_category_id' + ) diff --git a/python-sqlalchemy-utils/tests/relationships/test_select_aggregate.py b/python-sqlalchemy-utils/tests/relationships/test_select_aggregate.py new file mode 100644 index 0000000..f4fe687 --- /dev/null +++ b/python-sqlalchemy-utils/tests/relationships/test_select_aggregate.py @@ -0,0 +1,61 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.aggregates import select_aggregate +from tests import TestCase +from tests.mixins import ThreeLevelDeepManyToMany + + +def normalize(sql): + return ' '.join(sql.replace('\n', '').split()) + + +class TestAggregateQueryForDeepToManyToMany( + ThreeLevelDeepManyToMany, + TestCase +): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + create_tables = False + + def assert_sql(self, construct, sql): + assert normalize(str(construct)) == normalize(sql) + + def build_update(self, *relationships): + expr = sa.func.count(sa.text('1')) + return ( + self.Catalog.__table__.update().values( + _id=select_aggregate( + expr, + relationships + ).correlate(self.Catalog) + ) + ) + + def test_simple_join(self): + self.assert_sql( + self.build_update(self.Catalog.categories), + ( + '''UPDATE catalog SET _id=(SELECT count(1) AS count_1 + FROM category JOIN catalog_category ON category._id = + catalog_category.category_id WHERE catalog._id = + catalog_category.catalog_id)''' + ) + ) + + def test_two_relations(self): + self.assert_sql( + self.build_update( + self.Category.sub_categories, + self.Catalog.categories, + ), + ( + '''UPDATE catalog SET _id=(SELECT count(1) AS count_1 + FROM sub_category + JOIN category_subcategory + ON sub_category._id = category_subcategory.subcategory_id + JOIN category + ON category._id = category_subcategory.category_id + JOIN catalog_category + ON category._id = catalog_category.category_id + WHERE catalog._id = catalog_category.catalog_id)''' + ) + ) diff --git a/python-sqlalchemy-utils/tests/test_asserts.py b/python-sqlalchemy-utils/tests/test_asserts.py new file mode 100644 index 0000000..623ff06 --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_asserts.py @@ -0,0 +1,149 @@ +import pytest +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import ARRAY + +from sqlalchemy_utils import ( + assert_max_length, + assert_max_value, + assert_min_value, + assert_non_nullable, + assert_nullable +) +from tests import TestCase + + +class AssertionTestCase(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(20)) + age = sa.Column(sa.Integer, nullable=False) + email = sa.Column(sa.String(200), nullable=False, unique=True) + fav_numbers = sa.Column(ARRAY(sa.Integer)) + + __table_args__ = ( + sa.CheckConstraint(sa.and_(age >= 0, age <= 150)), + sa.CheckConstraint( + sa.and_( + sa.func.array_length(fav_numbers, 1) <= 8 + ) + ) + ) + + self.User = User + + def setup_method(self, method): + TestCase.setup_method(self, method) + user = self.User( + name='Someone', + email='someone@example.com', + age=15, + fav_numbers=[1, 2, 3] + ) + self.session.add(user) + self.session.commit() + self.user = user + + +class TestAssertMaxLengthWithArray(AssertionTestCase): + def test_with_max_length(self): + assert_max_length(self.user, 'fav_numbers', 8) + assert_max_length(self.user, 'fav_numbers', 8) + + def test_smaller_than_max_length(self): + with pytest.raises(AssertionError): + assert_max_length(self.user, 'fav_numbers', 7) + with pytest.raises(AssertionError): + assert_max_length(self.user, 'fav_numbers', 7) + + def test_bigger_than_max_length(self): + with pytest.raises(AssertionError): + assert_max_length(self.user, 'fav_numbers', 9) + with pytest.raises(AssertionError): + assert_max_length(self.user, 'fav_numbers', 9) + + +class TestAssertNonNullable(AssertionTestCase): + def test_non_nullable_column(self): + # Test everything twice so that session gets rolled back properly + assert_non_nullable(self.user, 'age') + assert_non_nullable(self.user, 'age') + + def test_nullable_column(self): + with pytest.raises(AssertionError): + assert_non_nullable(self.user, 'name') + with pytest.raises(AssertionError): + assert_non_nullable(self.user, 'name') + + +class TestAssertNullable(AssertionTestCase): + def test_nullable_column(self): + assert_nullable(self.user, 'name') + assert_nullable(self.user, 'name') + + def test_non_nullable_column(self): + with pytest.raises(AssertionError): + assert_nullable(self.user, 'age') + with pytest.raises(AssertionError): + assert_nullable(self.user, 'age') + + +class TestAssertMaxLength(AssertionTestCase): + def test_with_max_length(self): + assert_max_length(self.user, 'name', 20) + assert_max_length(self.user, 'name', 20) + + def test_with_non_nullable_column(self): + assert_max_length(self.user, 'email', 200) + assert_max_length(self.user, 'email', 200) + + def test_smaller_than_max_length(self): + with pytest.raises(AssertionError): + assert_max_length(self.user, 'name', 19) + with pytest.raises(AssertionError): + assert_max_length(self.user, 'name', 19) + + def test_bigger_than_max_length(self): + with pytest.raises(AssertionError): + assert_max_length(self.user, 'name', 21) + with pytest.raises(AssertionError): + assert_max_length(self.user, 'name', 21) + + +class TestAssertMinValue(AssertionTestCase): + def test_with_min_value(self): + assert_min_value(self.user, 'age', 0) + assert_min_value(self.user, 'age', 0) + + def test_smaller_than_min_value(self): + with pytest.raises(AssertionError): + assert_min_value(self.user, 'age', -1) + with pytest.raises(AssertionError): + assert_min_value(self.user, 'age', -1) + + def test_bigger_than_min_value(self): + with pytest.raises(AssertionError): + assert_min_value(self.user, 'age', 1) + with pytest.raises(AssertionError): + assert_min_value(self.user, 'age', 1) + + +class TestAssertMaxValue(AssertionTestCase): + def test_with_min_value(self): + assert_max_value(self.user, 'age', 150) + assert_max_value(self.user, 'age', 150) + + def test_smaller_than_max_value(self): + with pytest.raises(AssertionError): + assert_max_value(self.user, 'age', 149) + with pytest.raises(AssertionError): + assert_max_value(self.user, 'age', 149) + + def test_bigger_than_max_value(self): + with pytest.raises(AssertionError): + assert_max_value(self.user, 'age', 151) + with pytest.raises(AssertionError): + assert_max_value(self.user, 'age', 151) diff --git a/python-sqlalchemy-utils/tests/test_auto_delete_orphans.py b/python-sqlalchemy-utils/tests/test_auto_delete_orphans.py new file mode 100644 index 0000000..6efb5b9 --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_auto_delete_orphans.py @@ -0,0 +1,117 @@ +import sqlalchemy as sa +from pytest import raises + +from sqlalchemy_utils import auto_delete_orphans, ImproperlyConfigured +from tests import TestCase + + +class TestAutoDeleteOrphans(TestCase): + def create_models(self): + tagging = sa.Table( + 'tagging', + self.Base.metadata, + sa.Column( + 'tag_id', + sa.Integer, + sa.ForeignKey('tag.id', ondelete='cascade'), + primary_key=True + ), + sa.Column( + 'entry_id', + sa.Integer, + sa.ForeignKey('entry.id', ondelete='cascade'), + primary_key=True + ) + ) + + class Tag(self.Base): + __tablename__ = 'tag' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(100), unique=True, nullable=False) + + def __init__(self, name=None): + self.name = name + + class Entry(self.Base): + __tablename__ = 'entry' + + id = sa.Column(sa.Integer, primary_key=True) + + tags = sa.orm.relationship( + 'Tag', + secondary=tagging, + backref='entries' + ) + + auto_delete_orphans(Entry.tags) + + self.Tag = Tag + self.Entry = Entry + + def test_orphan_deletion(self): + r1 = self.Entry() + r2 = self.Entry() + r3 = self.Entry() + t1, t2, t3, t4 = ( + self.Tag('t1'), + self.Tag('t2'), + self.Tag('t3'), + self.Tag('t4') + ) + + r1.tags.extend([t1, t2]) + r2.tags.extend([t2, t3]) + r3.tags.extend([t4]) + self.session.add_all([r1, r2, r3]) + + assert self.session.query(self.Tag).count() == 4 + r2.tags.remove(t2) + assert self.session.query(self.Tag).count() == 4 + r1.tags.remove(t2) + assert self.session.query(self.Tag).count() == 3 + r1.tags.remove(t1) + assert self.session.query(self.Tag).count() == 2 + + +class TestAutoDeleteOrphansWithoutBackref(TestCase): + def create_models(self): + tagging = sa.Table( + 'tagging', + self.Base.metadata, + sa.Column( + 'tag_id', + sa.Integer, + sa.ForeignKey('tag.id', ondelete='cascade'), + primary_key=True + ), + sa.Column( + 'entry_id', + sa.Integer, + sa.ForeignKey('entry.id', ondelete='cascade'), + primary_key=True + ) + ) + + class Tag(self.Base): + __tablename__ = 'tag' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(100), unique=True, nullable=False) + + def __init__(self, name=None): + self.name = name + + class Entry(self.Base): + __tablename__ = 'entry' + + id = sa.Column(sa.Integer, primary_key=True) + + tags = sa.orm.relationship( + 'Tag', + secondary=tagging + ) + + self.Entry = Entry + + def test_orphan_deletion(self): + with raises(ImproperlyConfigured): + auto_delete_orphans(self.Entry.tags) diff --git a/python-sqlalchemy-utils/tests/test_case_insensitive_comparator.py b/python-sqlalchemy-utils/tests/test_case_insensitive_comparator.py new file mode 100644 index 0000000..19812e1 --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_case_insensitive_comparator.py @@ -0,0 +1,50 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import EmailType +from tests import TestCase + + +class TestCaseInsensitiveComparator(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + email = sa.Column(EmailType) + + def __repr__(self): + return 'Building(%r)' % self.id + + self.User = User + + def test_supports_equals(self): + query = ( + self.session.query(self.User) + .filter(self.User.email == u'email@example.com') + ) + + assert '"user".email = lower(:lower_1)' in str(query) + + def test_supports_in_(self): + query = ( + self.session.query(self.User) + .filter(self.User.email.in_([u'email@example.com', u'a'])) + ) + assert ( + '"user".email IN (lower(:lower_1), lower(:lower_2))' + in str(query) + ) + + def test_supports_notin_(self): + query = ( + self.session.query(self.User) + .filter(self.User.email.notin_([u'email@example.com', u'a'])) + ) + assert ( + '"user".email NOT IN (lower(:lower_1), lower(:lower_2))' + in str(query) + ) + + def test_does_not_apply_lower_to_types_that_are_already_lowercased(self): + assert str(self.User.email == self.User.email) == ( + '"user".email = "user".email' + ) diff --git a/python-sqlalchemy-utils/tests/test_expression_parser.py b/python-sqlalchemy-utils/tests/test_expression_parser.py new file mode 100644 index 0000000..d7cccac --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_expression_parser.py @@ -0,0 +1,88 @@ +import sqlalchemy as sa +from sqlalchemy.sql.elements import Cast, Null + +from sqlalchemy_utils import ExpressionParser +from tests import TestCase + + +class MyExpressionParser(ExpressionParser): + def __init__(self, some_class): + self.parent = some_class + + def column(self, column): + return getattr(self.parent, column.key) + + def instrumented_attribute(self, column): + return getattr(self.parent, column.key) + + +class TestExpressionParser(TestCase): + create_tables = False + + def setup_method(self, method): + TestCase.setup_method(self, method) + self.parser = MyExpressionParser(self.Category) + + def test_false_expression(self): + expr = self.parser(self.User.name.isnot(False)) + assert str(expr) == 'category.name IS NOT 0' + + def test_true_expression(self): + expr = self.parser(self.User.name.isnot(True)) + assert str(expr) == 'category.name IS NOT 1' + + def test_unary_expression(self): + expr = self.parser(~ self.User.name) + assert str(expr) == 'NOT category.name' + + def test_in_expression(self): + expr = self.parser(self.User.name.in_([2, 3])) + assert str(expr) == 'category.name IN (:name_1, :name_2)' + + def test_boolean_expression(self): + expr = self.parser(self.User.name == False) # noqa + assert str(expr) == 'category.name = 0' + + def test_label(self): + expr = self.parser(self.User.name.label('some_name')) + assert str(expr) == 'category.name' + + def test_like(self): + expr = self.parser(self.User.name.like(u'something')) + assert str(expr) == 'category.name LIKE :name_1' + + def test_cast(self): + expr = self.parser(Cast(self.User.name, sa.UnicodeText)) + assert str(expr) == 'CAST(category.name AS TEXT)' + + def test_case(self): + expr = self.parser( + sa.case( + [ + (self.User.name == 'wendy', 'W'), + (self.User.name == 'jack', 'J') + ], + else_='E' + ) + ) + assert str(expr) == ( + 'CASE WHEN (category.name = :name_1) ' + 'THEN :param_1 WHEN (category.name = :name_2) ' + 'THEN :param_2 ELSE :param_3 END' + ) + + def test_tuple(self): + expr = self.parser( + sa.tuple_(self.User.name, 3).in_([(u'someone', 3)]) + ) + assert str(expr) == ( + '(category.name, :param_1) IN ((:param_2, :param_3))' + ) + + def test_null(self): + expr = self.parser(self.User.name == Null()) + assert str(expr) == 'category.name IS NULL' + + def test_instrumented_attribute(self): + expr = self.parser(self.User.name) + assert str(expr) == 'Category.name' diff --git a/python-sqlalchemy-utils/tests/test_expressions.py b/python-sqlalchemy-utils/tests/test_expressions.py new file mode 100644 index 0000000..1bcd223 --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_expressions.py @@ -0,0 +1,157 @@ +import sqlalchemy as sa +from pytest import raises +from sqlalchemy.dialects import postgresql + +from sqlalchemy_utils import Asterisk, row_to_json +from sqlalchemy_utils.expressions import explain, explain_analyze +from tests import TestCase + + +class ExpressionTestCase(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + content = sa.Column(sa.UnicodeText) + + self.Article = Article + + def assert_startswith(self, query, query_part): + assert str( + query.compile(dialect=postgresql.dialect()) + ).startswith(query_part) + # Check that query executes properly + self.session.execute(query) + + +class TestExplain(ExpressionTestCase): + def test_render_explain(self): + self.assert_startswith( + explain(self.session.query(self.Article)), + 'EXPLAIN SELECT' + ) + + def test_render_explain_with_analyze(self): + self.assert_startswith( + explain(self.session.query(self.Article), analyze=True), + 'EXPLAIN (ANALYZE true) SELECT' + ) + + def test_with_string_as_stmt_param(self): + self.assert_startswith( + explain('SELECT 1 FROM article'), + 'EXPLAIN SELECT' + ) + + def test_format(self): + self.assert_startswith( + explain('SELECT 1 FROM article', format='json'), + 'EXPLAIN (FORMAT json) SELECT' + ) + + def test_timing(self): + self.assert_startswith( + explain('SELECT 1 FROM article', analyze=True, timing=False), + 'EXPLAIN (ANALYZE true, TIMING false) SELECT' + ) + + def test_verbose(self): + self.assert_startswith( + explain('SELECT 1 FROM article', verbose=True), + 'EXPLAIN (VERBOSE true) SELECT' + ) + + def test_buffers(self): + self.assert_startswith( + explain('SELECT 1 FROM article', analyze=True, buffers=True), + 'EXPLAIN (ANALYZE true, BUFFERS true) SELECT' + ) + + def test_costs(self): + self.assert_startswith( + explain('SELECT 1 FROM article', costs=False), + 'EXPLAIN (COSTS false) SELECT' + ) + + +class TestExplainAnalyze(ExpressionTestCase): + def test_render_explain_analyze(self): + assert str( + explain_analyze(self.session.query(self.Article)) + .compile( + dialect=postgresql.dialect() + ) + ).startswith('EXPLAIN (ANALYZE true) SELECT') + + +class TestAsterisk(object): + def test_with_table_object(self): + Base = sa.ext.declarative.declarative_base() + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + + assert str(Asterisk(Article.__table__)) == 'article.*' + + def test_with_quoted_identifier(self): + Base = sa.ext.declarative.declarative_base() + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + assert str(Asterisk(User.__table__).compile( + dialect=postgresql.dialect() + )) == '"user".*' + + +class TestRowToJson(object): + def test_compiler_with_default_dialect(self): + with raises(sa.exc.CompileError): + str(row_to_json(sa.text('article.*'))) + + def test_compiler_with_postgresql(self): + assert str(row_to_json(sa.text('article.*')).compile( + dialect=postgresql.dialect() + )) == 'row_to_json(article.*)' + + def test_type(self): + assert isinstance( + sa.func.row_to_json(sa.text('article.*')).type, + postgresql.JSON + ) + + +class TestArrayAgg(object): + def test_compiler_with_default_dialect(self): + with raises(sa.exc.CompileError): + str(sa.func.array_agg(sa.text('u.name'))) + + def test_compiler_with_postgresql(self): + assert str(sa.func.array_agg(sa.text('u.name')).compile( + dialect=postgresql.dialect() + )) == "array_agg(u.name)" + + def test_type(self): + assert isinstance( + sa.func.array_agg(sa.text('u.name')).type, + postgresql.ARRAY + ) + + def test_array_agg_with_default(self): + Base = sa.ext.declarative.declarative_base() + + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + + assert str(sa.func.array_agg(Article.id, [1]).compile( + dialect=postgresql.dialect() + )) == ( + 'coalesce(array_agg(article.id), CAST(ARRAY[%(param_1)s]' + ' AS INTEGER[]))' + ) diff --git a/python-sqlalchemy-utils/tests/test_instant_defaults_listener.py b/python-sqlalchemy-utils/tests/test_instant_defaults_listener.py new file mode 100644 index 0000000..0ec444a --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_instant_defaults_listener.py @@ -0,0 +1,27 @@ +from datetime import datetime + +import sqlalchemy as sa + +from sqlalchemy_utils.listeners import force_instant_defaults +from tests import TestCase + +force_instant_defaults() + + +class TestInstantDefaultListener(TestCase): + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255), default=u'Some article') + created_at = sa.Column(sa.DateTime, default=datetime.now) + + self.Article = Article + + def test_assigns_defaults_on_object_construction(self): + article = self.Article() + assert article.name == u'Some article' + + def test_callables_as_defaults(self): + article = self.Article() + assert isinstance(article.created_at, datetime) diff --git a/python-sqlalchemy-utils/tests/test_instrumented_list.py b/python-sqlalchemy-utils/tests/test_instrumented_list.py new file mode 100644 index 0000000..525d46b --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_instrumented_list.py @@ -0,0 +1,14 @@ +from tests import TestCase + + +class TestInstrumentedList(TestCase): + def test_any_returns_true_if_member_has_attr_defined(self): + category = self.Category() + category.articles.append(self.Article()) + category.articles.append(self.Article(name=u'some name')) + assert category.articles.any('name') + + def test_any_returns_false_if_no_member_has_attr_defined(self): + category = self.Category() + category.articles.append(self.Article()) + assert not category.articles.any('name') diff --git a/python-sqlalchemy-utils/tests/test_models.py b/python-sqlalchemy-utils/tests/test_models.py new file mode 100644 index 0000000..83aa912 --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_models.py @@ -0,0 +1,39 @@ +from datetime import datetime + +import sqlalchemy as sa + +from sqlalchemy_utils import Timestamp +from tests import TestCase + + +class TestTimestamp(TestCase): + + def create_models(self): + class Article(self.Base, Timestamp): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255), default=u'Some article') + + self.Article = Article + + def test_created(self): + then = datetime.utcnow() + article = self.Article() + + self.session.add(article) + self.session.commit() + + assert article.created >= then and article.created <= datetime.utcnow() + + def test_updated(self): + article = self.Article() + + self.session.add(article) + self.session.commit() + + then = datetime.utcnow() + article.name = u"Something" + + self.session.commit() + + assert article.updated >= then and article.updated <= datetime.utcnow() diff --git a/python-sqlalchemy-utils/tests/test_path.py b/python-sqlalchemy-utils/tests/test_path.py new file mode 100644 index 0000000..06af498 --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_path.py @@ -0,0 +1,182 @@ +import six +import sqlalchemy as sa +from pytest import mark +from sqlalchemy.util.langhelpers import symbol + +from sqlalchemy_utils.path import AttrPath, Path +from tests import TestCase + + +class TestAttrPath(TestCase): + def create_models(self): + class Document(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + class Section(self.Base): + __tablename__ = 'section' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + document_id = sa.Column( + sa.Integer, sa.ForeignKey(Document.id) + ) + + document = sa.orm.relationship(Document, backref='sections') + + class SubSection(self.Base): + __tablename__ = 'subsection' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + locale = sa.Column(sa.String(10)) + + section_id = sa.Column( + sa.Integer, sa.ForeignKey(Section.id) + ) + + section = sa.orm.relationship(Section, backref='subsections') + + self.Document = Document + self.Section = Section + self.SubSection = SubSection + + @mark.parametrize( + ('class_', 'path', 'direction'), + ( + ('SubSection', 'section', symbol('MANYTOONE')), + ) + ) + def test_direction(self, class_, path, direction): + assert ( + AttrPath(getattr(self, class_), path).direction == direction + ) + + def test_invert(self): + path = ~ AttrPath(self.SubSection, 'section.document') + assert path.parts == [ + self.Document.sections, + self.Section.subsections + ] + assert str(path.path) == 'sections.subsections' + + def test_len(self): + len(AttrPath(self.SubSection, 'section.document')) == 2 + + def test_init(self): + path = AttrPath(self.SubSection, 'section.document') + assert path.class_ == self.SubSection + assert path.path == Path('section.document') + + def test_iter(self): + path = AttrPath(self.SubSection, 'section.document') + assert list(path) == [ + self.SubSection.section, + self.Section.document + ] + + def test_repr(self): + path = AttrPath(self.SubSection, 'section.document') + assert repr(path) == ( + "AttrPath(SubSection, 'section.document')" + ) + + def test_index(self): + path = AttrPath(self.SubSection, 'section.document') + assert path.index(self.Section.document) == 1 + assert path.index(self.SubSection.section) == 0 + + def test_getitem(self): + path = AttrPath(self.SubSection, 'section.document') + assert path[0] is self.SubSection.section + assert path[1] is self.Section.document + + def test_getitem_with_slice(self): + path = AttrPath(self.SubSection, 'section.document') + assert path[:] == AttrPath(self.SubSection, 'section.document') + assert path[:-1] == AttrPath(self.SubSection, 'section') + assert path[1:] == AttrPath(self.Section, 'document') + + def test_eq(self): + assert ( + AttrPath(self.SubSection, 'section.document') == + AttrPath(self.SubSection, 'section.document') + ) + assert not ( + AttrPath(self.SubSection, 'section') == + AttrPath(self.SubSection, 'section.document') + ) + + def test_ne(self): + assert not ( + AttrPath(self.SubSection, 'section.document') != + AttrPath(self.SubSection, 'section.document') + ) + assert ( + AttrPath(self.SubSection, 'section') != + AttrPath(self.SubSection, 'section.document') + ) + + +class TestPath(object): + def test_init(self): + path = Path('attr.attr2') + assert path.path == 'attr.attr2' + + def test_init_with_path_object(self): + path = Path(Path('attr.attr2')) + assert path.path == 'attr.attr2' + + def test_iter(self): + path = Path('s.s2.s3') + assert list(path) == ['s', 's2', 's3'] + + @mark.parametrize(('path', 'length'), ( + (Path('s.s2.s3'), 3), + (Path('s.s2'), 2), + (Path(''), 0) + )) + def test_len(self, path, length): + return len(path) == length + + def test_reversed(self): + path = Path('s.s2.s3') + assert list(reversed(path)) == ['s3', 's2', 's'] + + def test_repr(self): + path = Path('s.s2') + assert repr(path) == "Path('s.s2')" + + def test_getitem(self): + path = Path('s.s2') + assert path[0] == 's' + assert path[1] == 's2' + + def test_str(self): + assert str(Path('s.s2')) == 's.s2' + + def test_index(self): + assert Path('s.s2.s3').index('s2') == 1 + + def test_unicode(self): + assert six.text_type(Path('s.s2')) == u's.s2' + + def test_getitem_with_slice(self): + path = Path('s.s2.s3') + assert path[1:] == Path('s2.s3') + + @mark.parametrize(('test', 'result'), ( + (Path('s.s2') == Path('s.s2'), True), + (Path('s.s2') == Path('s.s3'), False) + )) + def test_eq(self, test, result): + assert test is result + + @mark.parametrize(('test', 'result'), ( + (Path('s.s2') != Path('s.s2'), False), + (Path('s.s2') != Path('s.s3'), True) + )) + def test_ne(self, test, result): + assert test is result diff --git a/python-sqlalchemy-utils/tests/test_proxy_dict.py b/python-sqlalchemy-utils/tests/test_proxy_dict.py new file mode 100644 index 0000000..a87f39c --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_proxy_dict.py @@ -0,0 +1,119 @@ +import sqlalchemy as sa +from flexmock import flexmock + +from sqlalchemy_utils import proxy_dict, ProxyDict +from tests import TestCase + + +class TestProxyDict(TestCase): + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + description = sa.Column(sa.UnicodeText) + _translations = sa.orm.relationship( + 'ArticleTranslation', + lazy='dynamic', + cascade='all, delete-orphan', + passive_deletes=True, + backref=sa.orm.backref('parent'), + ) + + @property + def translations(self): + return proxy_dict( + self, + '_translations', + ArticleTranslation.locale + ) + + class ArticleTranslation(self.Base): + __tablename__ = 'article_translation' + + id = sa.Column( + sa.Integer, + sa.ForeignKey(Article.id), + autoincrement=True, + primary_key=True + ) + locale = sa.Column(sa.String(10), primary_key=True) + name = sa.Column(sa.UnicodeText) + + self.Article = Article + self.ArticleTranslation = ArticleTranslation + + def test_access_key_for_pending_parent(self): + article = self.Article() + self.session.add(article) + assert article.translations['en'] + + def test_access_key_for_transient_parent(self): + article = self.Article() + assert article.translations['en'] + + def test_cache(self): + article = self.Article() + ( + flexmock(ProxyDict) + .should_receive('fetch') + .once() + ) + self.session.add(article) + self.session.commit() + article.translations['en'] + article.translations['en'] + + def test_set_updates_cache(self): + article = self.Article() + ( + flexmock(ProxyDict) + .should_receive('fetch') + .once() + ) + self.session.add(article) + self.session.commit() + article.translations['en'] + article.translations['en'] = self.ArticleTranslation( + locale='en', + name=u'something' + ) + article.translations['en'] + + def test_contains_efficiency(self): + article = self.Article() + self.session.add(article) + self.session.commit() + article.id + query_count = self.connection.query_count + 'en' in article.translations + 'en' in article.translations + 'en' in article.translations + assert self.connection.query_count == query_count + 1 + + def test_getitem_with_none_value_in_cache(self): + article = self.Article() + self.session.add(article) + self.session.commit() + article.id + 'en' in article.translations + assert article.translations['en'] + + def test_contains(self): + article = self.Article() + assert 'en' not in article.translations + # does not auto-append new translation + assert 'en' not in article.translations + + def test_committing_session_empties_proxy_dict_cache(self): + article = self.Article() + ( + flexmock(ProxyDict) + .should_receive('fetch') + .twice() + ) + self.session.add(article) + self.session.commit() + article.translations['en'] + self.session.commit() + article.translations['en'] diff --git a/python-sqlalchemy-utils/tests/test_query_chain.py b/python-sqlalchemy-utils/tests/test_query_chain.py new file mode 100644 index 0000000..89b4ea4 --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_query_chain.py @@ -0,0 +1,93 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import QueryChain +from tests import TestCase + + +class TestQueryChain(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + + self.User = User + self.Article = Article + self.BlogPost = BlogPost + + def setup_method(self, method): + TestCase.setup_method(self, method) + self.users = [ + self.User(), + self.User() + ] + self.articles = [ + self.Article(), + self.Article(), + self.Article(), + self.Article() + ] + self.posts = [ + self.BlogPost(), + self.BlogPost(), + self.BlogPost(), + ] + + self.session.add_all(self.users) + self.session.add_all(self.articles) + self.session.add_all(self.posts) + self.session.commit() + + self.chain = QueryChain( + [ + self.session.query(self.User).order_by('id'), + self.session.query(self.Article).order_by('id'), + self.session.query(self.BlogPost).order_by('id') + ] + ) + + def test_iter(self): + assert len(list(self.chain)) == 9 + + def test_iter_with_limit(self): + chain = self.chain.limit(4) + objects = list(chain) + assert self.users == objects[0:2] + assert self.articles[0:2] == objects[2:] + + def test_iter_with_offset(self): + chain = self.chain.offset(3) + objects = list(chain) + assert self.articles[1:] + self.posts == objects + + def test_iter_with_limit_and_offset(self): + chain = self.chain.offset(3).limit(4) + objects = list(chain) + assert self.articles[1:] + self.posts[0:1] == objects + + def test_iter_with_offset_spanning_multiple_queries(self): + chain = self.chain.offset(7) + objects = list(chain) + assert self.posts[1:] == objects + + def test_repr(self): + assert repr(self.chain) == '' % id(self.chain) + + def test_getitem_with_slice(self): + chain = self.chain[1:] + assert chain._offset == 1 + assert chain._limit is None + + def test_getitem_with_single_key(self): + article = self.chain[2] + assert article == self.articles[0] + + def test_count(self): + assert self.chain.count() == 9 diff --git a/python-sqlalchemy-utils/tests/test_sort_query.py b/python-sqlalchemy-utils/tests/test_sort_query.py new file mode 100644 index 0000000..d72ab09 --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_sort_query.py @@ -0,0 +1,334 @@ +import sqlalchemy as sa +from pytest import raises + +from sqlalchemy_utils import sort_query +from sqlalchemy_utils.functions import QuerySorterException +from tests import assert_contains, TestCase + + +class TestSortQuery(TestCase): + def test_without_sort_param_returns_the_query_object_untouched(self): + query = self.session.query(self.Article) + query = sort_query(query, '') + assert query == query + + def test_column_ascending(self): + query = sort_query(self.session.query(self.Article), 'name') + assert_contains('ORDER BY article.name ASC', query) + + def test_column_descending(self): + query = sort_query(self.session.query(self.Article), '-name') + assert_contains('ORDER BY article.name DESC', query) + + def test_skips_unknown_columns(self): + query = self.session.query(self.Article) + query = sort_query(query, '-unknown') + assert query == query + + def test_non_silent_mode(self): + query = self.session.query(self.Article) + with raises(QuerySorterException): + sort_query(query, '-unknown', silent=False) + + def test_join(self): + query = ( + self.session.query(self.Article) + .join(self.Article.category) + ) + query = sort_query(query, 'name', silent=False) + assert_contains('ORDER BY article.name ASC', query) + + def test_calculated_value_ascending(self): + query = self.session.query( + self.Category, sa.func.count(self.Article.id).label('articles') + ) + query = sort_query(query, 'articles') + assert_contains('ORDER BY articles ASC', query) + + def test_calculated_value_descending(self): + query = self.session.query( + self.Category, sa.func.count(self.Article.id).label('articles') + ) + query = sort_query(query, '-articles') + assert_contains('ORDER BY articles DESC', query) + + def test_subqueried_scalar(self): + article_count = ( + sa.sql.select( + [sa.func.count(self.Article.id)], + from_obj=[self.Article.__table__] + ) + .where(self.Article.category_id == self.Category.id) + .correlate(self.Category.__table__) + ) + + query = self.session.query( + self.Category, article_count.label('articles') + ) + query = sort_query(query, '-articles') + assert_contains('ORDER BY articles DESC', query) + + def test_aliased_joined_entity(self): + alias = sa.orm.aliased(self.Category, name='categories') + query = self.session.query( + self.Article + ).join( + alias, self.Article.category + ) + query = sort_query(query, '-categories-name') + assert_contains('ORDER BY categories.name DESC', query) + + def test_joined_table_column(self): + query = self.session.query(self.Article).join(self.Article.category) + query = sort_query(query, 'category-name') + assert_contains('category.name ASC', query) + + def test_multiple_columns(self): + query = self.session.query(self.Article) + query = sort_query(query, 'name', 'id') + assert_contains('article.name ASC, article.id ASC', query) + + def test_column_property(self): + self.Category.article_count = sa.orm.column_property( + sa.select([sa.func.count(self.Article.id)]) + .where(self.Article.category_id == self.Category.id) + .label('article_count') + ) + + query = self.session.query(self.Category) + query = sort_query(query, 'article_count') + assert_contains('article_count ASC', query) + + def test_column_property_descending(self): + self.Category.article_count = sa.orm.column_property( + sa.select([sa.func.count(self.Article.id)]) + .where(self.Article.category_id == self.Category.id) + .label('article_count') + ) + + query = self.session.query(self.Category) + query = sort_query(query, '-article_count') + assert_contains('article_count DESC', query) + + def test_relationship_property(self): + query = self.session.query(self.Category) + query = sort_query(query, 'articles') + assert 'ORDER BY' not in str(query) + + def test_regular_property(self): + query = self.session.query(self.Category) + query = sort_query(query, 'name_alias') + assert 'ORDER BY' not in str(query) + + def test_synonym_property(self): + query = self.session.query(self.Category) + query = sort_query(query, 'name_synonym') + assert_contains('ORDER BY category.name ASC', query) + + def test_hybrid_property(self): + query = self.session.query(self.Category) + query = sort_query(query, 'articles_count') + assert_contains('ORDER BY (SELECT count(article.id) AS count_1', query) + + def test_hybrid_property_descending(self): + query = self.session.query(self.Category) + query = sort_query(query, '-articles_count') + assert_contains( + 'ORDER BY (SELECT count(article.id) AS count_1', + query + ) + assert ' DESC' in str(query) + + def test_assigned_hybrid_property(self): + def getter(self): + return self.name + + self.Article.some_hybrid = sa.ext.hybrid.hybrid_property( + fget=getter + ) + query = self.session.query(self.Article) + query = sort_query(query, 'some_hybrid') + assert_contains('ORDER BY article.name ASC', query) + + def test_with_mapper_and_column_property(self): + class Apple(self.Base): + __tablename__ = 'apple' + id = sa.Column(sa.Integer, primary_key=True) + article_id = sa.Column(sa.Integer, sa.ForeignKey(self.Article.id)) + + self.Article.apples = sa.orm.relationship(Apple) + + self.Article.apple_count = sa.orm.column_property( + sa.select([sa.func.count(Apple.id)]) + .where(Apple.article_id == self.Article.id) + .correlate(self.Article.__table__) + .label('apple_count'), + deferred=True + ) + query = ( + self.session.query(sa.inspect(self.Article)) + .outerjoin(self.Article.apples) + .options( + sa.orm.undefer(self.Article.apple_count) + ) + .options(sa.orm.contains_eager(self.Article.apples)) + ) + query = sort_query(query, 'apple_count') + assert 'ORDER BY apple_count' in str(query) + + def test_table(self): + query = self.session.query(self.Article.__table__) + query = sort_query(query, 'name') + assert_contains('ORDER BY article.name', query) + + +class TestSortQueryRelationshipCounts(TestCase): + """ + Currently this doesn't work with SQLite + """ + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def test_relation_hybrid_property(self): + query = ( + self.session.query(self.Article) + .join(self.Article.category) + ).group_by(self.Article.id) + query = sort_query(query, '-category-articles_count') + assert_contains('ORDER BY (SELECT count(article.id) AS count_1', query) + + def test_aliased_hybrid_property(self): + alias = sa.orm.aliased( + self.Category, + name='categories' + ) + query = ( + self.session.query(self.Article) + .outerjoin(alias, self.Article.category) + .options( + sa.orm.contains_eager(self.Article.category, alias=alias) + ) + ).group_by(alias.id, self.Article.id) + query = sort_query(query, '-categories-articles_count') + assert_contains('ORDER BY (SELECT count(article.id) AS count_1', query) + + +class TestSortQueryWithPolymorphicInheritance(TestCase): + """ + Currently this doesn't work with SQLite + """ + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class TextItem(self.Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type, + 'with_polymorphic': '*' + } + + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + category = sa.Column(sa.Unicode(255)) + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + + self.TextItem = TextItem + self.Article = Article + + def test_column_property(self): + self.TextItem.item_count = sa.orm.column_property( + sa.select( + [ + sa.func.count('1') + ], + ) + .select_from(self.TextItem.__table__) + .label('item_count') + ) + + query = sort_query( + self.session.query(self.TextItem), + 'item_count' + ) + assert_contains('ORDER BY item_count', query) + + def test_child_class_attribute(self): + query = sort_query( + self.session.query(self.TextItem), + 'category' + ) + assert_contains('ORDER BY article.category ASC', query) + + def test_with_ambiguous_column(self): + query = sort_query( + self.session.query(self.TextItem), + 'id' + ) + assert_contains('ORDER BY text_item.id ASC', query) + + +class TestSortQueryWithCustomPolymorphic(TestCase): + """ + Currently this doesn't work with SQLite + """ + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class TextItem(self.Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type, + } + + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + category = sa.Column(sa.Unicode(255)) + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + + class BlogPost(TextItem): + __tablename__ = 'blog_post' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': u'blog_post' + } + + self.TextItem = TextItem + self.Article = Article + self.BlogPost = BlogPost + + def test_with_unknown_column(self): + query = sort_query( + self.session.query( + sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) + ), + 'category' + ) + assert 'ORDER BY' not in str(query) + + def test_with_existing_column(self): + query = sort_query( + self.session.query( + sa.orm.with_polymorphic(self.TextItem, [self.Article]) + ), + 'category' + ) + assert 'ORDER BY' in str(query) diff --git a/python-sqlalchemy-utils/tests/test_translation_hybrid.py b/python-sqlalchemy-utils/tests/test_translation_hybrid.py new file mode 100644 index 0000000..dc81f65 --- /dev/null +++ b/python-sqlalchemy-utils/tests/test_translation_hybrid.py @@ -0,0 +1,107 @@ +import sqlalchemy as sa +from flexmock import flexmock +from pytest import mark +from sqlalchemy.dialects.postgresql import HSTORE + +from sqlalchemy_utils import i18n, TranslationHybrid # noqa +from tests import TestCase + + +@mark.skipif('i18n.babel is None') +class TestTranslationHybrid(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class City(self.Base): + __tablename__ = 'city' + id = sa.Column(sa.Integer, primary_key=True) + name_translations = sa.Column(HSTORE) + name = self.translation_hybrid(name_translations) + locale = 'en' + + self.City = City + + def setup_method(self, method): + self.translation_hybrid = TranslationHybrid('fi', 'en') + TestCase.setup_method(self, method) + + def test_using_hybrid_as_constructor(self): + city = self.City(name='Helsinki') + assert city.name_translations['fi'] == 'Helsinki' + + def test_if_no_translation_exists_returns_none(self): + city = self.City() + assert city.name is None + + def test_custom_default_value(self): + self.translation_hybrid.default_value = 'Some value' + city = self.City() + assert city.name is 'Some value' + + def test_fall_back_to_default_translation(self): + city = self.City(name_translations={'en': 'Helsinki'}) + self.translation_hybrid.current_locale = 'sv' + assert city.name == 'Helsinki' + + def test_fallback_to_dynamic_locale(self): + self.translation_hybrid.current_locale = 'en' + self.translation_hybrid.default_locale = lambda self: self.locale + city = self.City(name_translations={}) + city.locale = 'fi' + city.name_translations['fi'] = 'Helsinki' + + assert city.name == 'Helsinki' + + @mark.parametrize( + ('name_translations', 'name'), + ( + ({'fi': 'Helsinki', 'en': 'Helsing'}, 'Helsinki'), + ({'en': 'Helsinki'}, 'Helsinki'), + ({'fi': 'Helsinki'}, 'Helsinki'), + ({}, None), + ) + ) + def test_hybrid_as_an_expression(self, name_translations, name): + city = self.City(name_translations=name_translations) + self.session.add(city) + self.session.commit() + + assert self.session.query(self.City.name).scalar() == name + + def test_dynamic_locale(self): + translation_hybrid = TranslationHybrid( + lambda obj: obj.locale, + 'fi' + ) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name_translations = sa.Column(HSTORE) + name = translation_hybrid(name_translations) + locale = sa.Column(sa.String) + + assert ( + 'coalesce(article.name_translations -> article.locale' + in str(Article.name) + ) + + def test_locales_casted_only_in_compilation_phase(self): + class LocaleGetter(object): + def current_locale(self): + return lambda obj: obj.locale + + flexmock(LocaleGetter).should_receive('current_locale').never() + translation_hybrid = TranslationHybrid( + LocaleGetter().current_locale, + 'fi' + ) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name_translations = sa.Column(HSTORE) + name = translation_hybrid(name_translations) + locale = sa.Column(sa.String) + + Article.name diff --git a/python-sqlalchemy-utils/tests/types/__init__.py b/python-sqlalchemy-utils/tests/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python-sqlalchemy-utils/tests/types/test_arrow.py b/python-sqlalchemy-utils/tests/types/test_arrow.py new file mode 100644 index 0000000..f3ca975 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_arrow.py @@ -0,0 +1,52 @@ +from datetime import datetime + +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils.types import arrow +from tests import TestCase + + +@mark.skipif('arrow.arrow is None') +class TestArrowDateTimeType(TestCase): + def create_models(self): + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + created_at = sa.Column(arrow.ArrowType) + + self.Article = Article + + def test_parameter_processing(self): + article = self.Article( + created_at=arrow.arrow.get(datetime(2000, 11, 1)) + ) + + self.session.add(article) + self.session.commit() + + article = self.session.query(self.Article).first() + assert article.created_at.datetime + + def test_string_coercion(self): + article = self.Article( + created_at='1367900664' + ) + assert article.created_at.year == 2013 + + def test_utc(self): + time = arrow.arrow.utcnow() + article = self.Article(created_at=time) + self.session.add(article) + assert article.created_at == time + self.session.commit() + assert article.created_at == time + + def test_other_tz(self): + time = arrow.arrow.utcnow() + local = time.to('US/Pacific') + article = self.Article(created_at=local) + self.session.add(article) + assert article.created_at == time == local + self.session.commit() + assert article.created_at == time diff --git a/python-sqlalchemy-utils/tests/types/test_choice.py b/python-sqlalchemy-utils/tests/types/test_choice.py new file mode 100644 index 0000000..831f0d6 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_choice.py @@ -0,0 +1,181 @@ +import sqlalchemy as sa +from flexmock import flexmock +from pytest import mark, raises + +from sqlalchemy_utils import Choice, ChoiceType, ImproperlyConfigured +from sqlalchemy_utils.types.choice import Enum +from tests import TestCase + + +class TestChoice(object): + def test_equality_operator(self): + assert Choice(1, 1) == 1 + assert 1 == Choice(1, 1) + assert Choice(1, 1) == Choice(1, 1) + + def test_non_equality_operator(self): + assert Choice(1, 1) != 2 + assert not (Choice(1, 1) != 1) + + +class TestChoiceType(TestCase): + def create_models(self): + class User(self.Base): + TYPES = [ + ('admin', 'Admin'), + ('regular-user', 'Regular user') + ] + + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(ChoiceType(TYPES)) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_python_type(self): + type_ = self.User.__table__.c.type.type + assert type_.python_type + + def test_string_processing(self): + flexmock(ChoiceType).should_receive('_coerce').and_return( + u'admin' + ) + user = self.User( + type=u'admin' + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.type.value == u'Admin' + + def test_parameter_processing(self): + user = self.User( + type=u'admin' + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.type.value == u'Admin' + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(type=u'admin') + + assert isinstance(user.type, Choice) + + def test_throws_exception_if_no_choices_given(self): + with raises(ImproperlyConfigured): + ChoiceType([]) + + +class TestChoiceTypeWithCustomUnderlyingType(TestCase): + def test_init_type(self): + type_ = ChoiceType([(1, u'something')], impl=sa.Integer) + assert type_.impl == sa.Integer + + +@mark.skipif('Enum is None') +class TestEnumType(TestCase): + def create_models(self): + class OrderStatus(Enum): + unpaid = 0 + paid = 1 + + class Order(self.Base): + __tablename__ = 'order' + id_ = sa.Column(sa.Integer, primary_key=True) + status = sa.Column( + ChoiceType(OrderStatus, impl=sa.Integer()), + default=OrderStatus.unpaid, + ) + + def __repr__(self): + return 'Order(%r, %r)' % (self.id_, self.status) + + class OrderNullable(self.Base): + __tablename__ = 'order_nullable' + id_ = sa.Column(sa.Integer, primary_key=True) + status = sa.Column( + ChoiceType(OrderStatus, impl=sa.Integer()), + nullable=True, + ) + + self.OrderStatus = OrderStatus + self.Order = Order + self.OrderNullable = OrderNullable + + def test_parameter_initialization(self): + order = self.Order() + + self.session.add(order) + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.unpaid + assert order.status.value == 0 + + def test_setting_by_value(self): + order = self.Order() + order.status = 1 + + self.session.add(order) + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.paid + + def test_setting_by_enum(self): + order = self.Order() + order.status = self.OrderStatus.paid + + self.session.add(order) + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.paid + + def test_setting_value_that_resolves_to_none(self): + order = self.Order() + order.status = 0 + + self.session.add(order) + self.session.commit() + + order = self.session.query(self.Order).first() + assert order.status is self.OrderStatus.unpaid + + def test_setting_to_wrong_enum_raises_valueerror(self): + class WrongEnum(Enum): + foo = 0 + bar = 1 + + order = self.Order() + + with raises(ValueError): + order.status = WrongEnum.foo + + def test_setting_to_uncoerceable_type_raises_valueerror(self): + order = self.Order() + with raises(ValueError): + order.status = 'Bad value' + + def test_order_nullable_stores_none(self): + # With nullable=False as in `Order`, a `None` value is always + # converted to the default value, unless we explicitly set it to + # sqlalchemy.sql.null(), so we use this class to test our ability + # to set and retrive `None`. + order_nullable = self.OrderNullable() + assert order_nullable.status is None + + order_nullable.status = None + + self.session.add(order_nullable) + self.session.commit() + + assert order_nullable.status is None diff --git a/python-sqlalchemy-utils/tests/types/test_color.py b/python-sqlalchemy-utils/tests/types/test_color.py new file mode 100644 index 0000000..aa1d60c --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_color.py @@ -0,0 +1,56 @@ +import sqlalchemy as sa +from flexmock import flexmock +from pytest import mark + +from sqlalchemy_utils import ColorType, types # noqa +from tests import TestCase + + +@mark.skipif('types.color.python_colour_type is None') +class TestColorType(TestCase): + def create_models(self): + class Document(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + bg_color = sa.Column(ColorType) + + def __repr__(self): + return 'Document(%r)' % self.id + + self.Document = Document + + def test_string_parameter_processing(self): + from colour import Color + + flexmock(ColorType).should_receive('_coerce').and_return( + u'white' + ) + document = self.Document( + bg_color='white' + ) + + self.session.add(document) + self.session.commit() + + document = self.session.query(self.Document).first() + assert document.bg_color.hex == Color(u'white').hex + + def test_color_parameter_processing(self): + from colour import Color + + document = self.Document( + bg_color=Color(u'white') + ) + + self.session.add(document) + self.session.commit() + + document = self.session.query(self.Document).first() + assert document.bg_color.hex == Color(u'white').hex + + def test_scalar_attributes_get_coerced_to_objects(self): + from colour import Color + + document = self.Document(bg_color='white') + + assert isinstance(document.bg_color, Color) diff --git a/python-sqlalchemy-utils/tests/types/test_composite.py b/python-sqlalchemy-utils/tests/types/test_composite.py new file mode 100644 index 0000000..d748606 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_composite.py @@ -0,0 +1,310 @@ +import sqlalchemy as sa +from pytest import mark +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +from sqlalchemy_utils import ( + CompositeArray, + CompositeType, + Currency, + CurrencyType, + i18n, + NumericRangeType, + register_composites, + remove_composite_listeners +) +from sqlalchemy_utils.types import pg_composite +from sqlalchemy_utils.types.range import intervals +from tests import TestCase + + +class TestCompositeTypeWithRegularTypes(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class Account(self.Base): + __tablename__ = 'account' + id = sa.Column(sa.Integer, primary_key=True) + balance = sa.Column( + CompositeType( + 'money_type', + [ + sa.Column('currency', sa.String), + sa.Column('amount', sa.Integer) + ] + ) + ) + + self.Account = Account + + def test_parameter_processing(self): + account = self.Account( + balance=('USD', 15) + ) + + self.session.add(account) + self.session.commit() + + account = self.session.query(self.Account).first() + assert account.balance.currency == 'USD' + assert account.balance.amount == 15 + + +@mark.skipif('i18n.babel is None') +class TestCompositeTypeWithTypeDecorators(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def setup_method(self, method): + TestCase.setup_method(self, method) + i18n.get_locale = lambda: i18n.babel.Locale('en') + + def create_models(self): + class Account(self.Base): + __tablename__ = 'account' + id = sa.Column(sa.Integer, primary_key=True) + balance = sa.Column( + CompositeType( + 'money_type', + [ + sa.Column('currency', CurrencyType), + sa.Column('amount', sa.Integer) + ] + ) + ) + + self.Account = Account + + def test_result_set_processing(self): + account = self.Account( + balance=('USD', 15) + ) + + self.session.add(account) + self.session.commit() + + account = self.session.query(self.Account).first() + assert account.balance.currency == Currency('USD') + assert account.balance.amount == 15 + + def test_parameter_processing(self): + account = self.Account( + balance=(Currency('USD'), 15) + ) + + self.session.add(account) + self.session.commit() + + account = self.session.query(self.Account).first() + assert account.balance.currency == Currency('USD') + assert account.balance.amount == 15 + + +@mark.skipif('i18n.babel is None') +class TestCompositeTypeInsideArray(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def setup_method(self, method): + self.type = CompositeType( + 'money_type', + [ + sa.Column('currency', CurrencyType), + sa.Column('amount', sa.Integer) + ] + ) + + TestCase.setup_method(self, method) + i18n.get_locale = lambda: i18n.babel.Locale('en') + + def create_models(self): + class Account(self.Base): + __tablename__ = 'account' + id = sa.Column(sa.Integer, primary_key=True) + balances = sa.Column( + CompositeArray(self.type) + ) + + self.Account = Account + + def test_parameter_processing(self): + account = self.Account( + balances=[ + self.type.type_cls(Currency('USD'), 15), + self.type.type_cls(Currency('AUD'), 20) + ] + ) + + self.session.add(account) + self.session.commit() + + account = self.session.query(self.Account).first() + assert account.balances[0].currency == Currency('USD') + assert account.balances[0].amount == 15 + assert account.balances[1].currency == Currency('AUD') + assert account.balances[1].amount == 20 + + +@mark.skipif('intervals is None') +class TestCompositeTypeWithRangeTypeInsideArray(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def setup_method(self, method): + self.type = CompositeType( + 'category', + [ + sa.Column('scale', NumericRangeType), + sa.Column('name', sa.String) + ] + ) + + TestCase.setup_method(self, method) + + def create_models(self): + class Account(self.Base): + __tablename__ = 'account' + id = sa.Column(sa.Integer, primary_key=True) + categories = sa.Column( + CompositeArray(self.type) + ) + + self.Account = Account + + def test_parameter_processing_with_named_tuple(self): + account = self.Account( + categories=[ + self.type.type_cls( + intervals.DecimalInterval([15, 18]), + 'bad' + ), + self.type.type_cls( + intervals.DecimalInterval([18, 20]), + 'good' + ) + ] + ) + + self.session.add(account) + self.session.commit() + + account = self.session.query(self.Account).first() + assert ( + account.categories[0].scale == intervals.DecimalInterval([15, 18]) + ) + assert account.categories[0].name == 'bad' + assert ( + account.categories[1].scale == intervals.DecimalInterval([18, 20]) + ) + assert account.categories[1].name == 'good' + + def test_parameter_processing_with_tuple(self): + account = self.Account( + categories=[ + (intervals.DecimalInterval([15, 18]), 'bad'), + (intervals.DecimalInterval([18, 20]), 'good') + ] + ) + + self.session.add(account) + self.session.commit() + + account = self.session.query(self.Account).first() + assert ( + account.categories[0].scale == intervals.DecimalInterval([15, 18]) + ) + assert account.categories[0].name == 'bad' + assert ( + account.categories[1].scale == intervals.DecimalInterval([18, 20]) + ) + assert account.categories[1].name == 'good' + + def test_parameter_processing_with_nulls_as_composite_fields(self): + account = self.Account( + categories=[ + (None, 'bad'), + (intervals.DecimalInterval([18, 20]), None) + ] + ) + self.session.add(account) + self.session.commit() + assert account.categories[0].scale is None + assert account.categories[0].name == 'bad' + assert ( + account.categories[1].scale == intervals.DecimalInterval([18, 20]) + ) + assert account.categories[1].name is None + + def test_parameter_processing_with_nulls_as_composites(self): + account = self.Account( + categories=[ + (None, None), + None + ] + ) + self.session.add(account) + self.session.commit() + assert account.categories[0].scale is None + assert account.categories[0].name is None + assert account.categories[1] is None + + +class TestCompositeTypeWhenTypeAlreadyExistsInDatabase(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def setup_method(self, method): + self.engine = create_engine(self.dns) + self.engine.echo = True + self.connection = self.engine.connect() + self.Base = declarative_base() + pg_composite.registered_composites = {} + + self.type = CompositeType( + 'money_type', + [ + sa.Column('currency', sa.String), + sa.Column('amount', sa.Integer) + ] + ) + + self.create_models() + sa.orm.configure_mappers() + + Session = sessionmaker(bind=self.connection) + self.session = Session() + self.session.execute( + "CREATE TYPE money_type AS (currency VARCHAR, amount INTEGER)" + ) + self.session.execute( + """CREATE TABLE account ( + id SERIAL, balance MONEY_TYPE, PRIMARY KEY(id) + )""" + ) + register_composites(self.connection) + + def teardown_method(self, method): + self.session.execute('DROP TABLE account') + self.session.execute('DROP TYPE money_type') + self.session.commit() + self.session.close_all() + self.connection.close() + remove_composite_listeners() + self.engine.dispose() + + def create_models(self): + class Account(self.Base): + __tablename__ = 'account' + id = sa.Column(sa.Integer, primary_key=True) + balance = sa.Column(self.type) + + self.Account = Account + + def test_parameter_processing(self): + account = self.Account( + balance=('USD', 15), + ) + + self.session.add(account) + self.session.commit() + + account = self.session.query(self.Account).first() + assert account.balance.currency == 'USD' + assert account.balance.amount == 15 diff --git a/python-sqlalchemy-utils/tests/types/test_country.py b/python-sqlalchemy-utils/tests/types/test_country.py new file mode 100644 index 0000000..59bae9d --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_country.py @@ -0,0 +1,34 @@ +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils import Country, CountryType, i18n # noqa +from tests import TestCase + + +@mark.skipif('i18n.babel is None') +class TestCountryType(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + country = sa.Column(CountryType) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_parameter_processing(self): + user = self.User( + country=Country(u'FI') + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.country.name == u'Finland' + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(country='FI') + assert isinstance(user.country, Country) diff --git a/python-sqlalchemy-utils/tests/types/test_currency.py b/python-sqlalchemy-utils/tests/types/test_currency.py new file mode 100644 index 0000000..caa7a1d --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_currency.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils import Currency, CurrencyType, i18n +from tests import TestCase + + +@mark.skipif('i18n.babel is None') +class TestCurrencyType(TestCase): + def setup_method(self, method): + TestCase.setup_method(self, method) + i18n.get_locale = lambda: i18n.babel.Locale('en') + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + currency = sa.Column(CurrencyType) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_parameter_processing(self): + user = self.User( + currency=Currency('USD') + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.currency.name == u'US Dollar' + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(currency='USD') + assert isinstance(user.currency, Currency) diff --git a/python-sqlalchemy-utils/tests/types/test_date_range.py b/python-sqlalchemy-utils/tests/types/test_date_range.py new file mode 100644 index 0000000..238e30f --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_date_range.py @@ -0,0 +1,100 @@ +from datetime import datetime, timedelta + +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils import DateRangeType +from tests import TestCase + +intervals = None +inf = 0 +try: + import intervals + from infinity import inf +except ImportError: + pass + + +@mark.skipif('intervals is None') +class DateRangeTestCase(TestCase): + def create_models(self): + class Booking(self.Base): + __tablename__ = 'booking' + id = sa.Column(sa.Integer, primary_key=True) + during = sa.Column(DateRangeType) + + self.Booking = Booking + + def create_booking(self, date_range): + booking = self.Booking( + during=date_range + ) + + self.session.add(booking) + self.session.commit() + return self.session.query(self.Booking).first() + + def test_nullify_range(self): + booking = self.create_booking(None) + assert booking.during is None + + @mark.parametrize( + ('date_range'), + ( + [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()], + [datetime(2015, 1, 1).date(), inf], + [-inf, datetime(2015, 1, 1).date()] + ) + ) + def test_save_date_range(self, date_range): + booking = self.create_booking(date_range) + assert booking.during.lower == date_range[0] + assert booking.during.upper == date_range[1] + + def test_nullify_date_range(self): + booking = self.Booking( + during=intervals.DateInterval( + [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()] + ) + ) + + self.session.add(booking) + self.session.commit() + + booking = self.session.query(self.Booking).first() + booking.during = None + self.session.commit() + + booking = self.session.query(self.Booking).first() + assert booking.during is None + + def test_integer_coercion(self): + booking = self.Booking(during=datetime(2015, 1, 1).date()) + assert booking.during.lower == datetime(2015, 1, 1).date() + assert booking.during.upper == datetime(2015, 1, 1).date() + + +class TestDateRangeOnPostgres(DateRangeTestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + @mark.parametrize( + ('date_range', 'length'), + ( + ( + [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()], + timedelta(days=2) + ), + ( + [datetime(2015, 1, 1).date(), datetime(2015, 1, 1).date()], + timedelta(days=0) + ), + ([-inf, datetime(2015, 1, 1).date()], None), + ([datetime(2015, 1, 1).date(), inf], None), + ) + ) + def test_length(self, date_range, length): + self.create_booking(date_range) + query = ( + self.session.query(self.Booking.during.length) + ) + assert query.scalar() == length diff --git a/python-sqlalchemy-utils/tests/types/test_datetime_range.py b/python-sqlalchemy-utils/tests/types/test_datetime_range.py new file mode 100644 index 0000000..b6bf731 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_datetime_range.py @@ -0,0 +1,100 @@ +from datetime import datetime, timedelta + +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils import DateTimeRangeType +from tests import TestCase + +intervals = None +inf = 0 +try: + import intervals + from infinity import inf +except ImportError: + pass + + +@mark.skipif('intervals is None') +class DateRangeTestCase(TestCase): + def create_models(self): + class Booking(self.Base): + __tablename__ = 'booking' + id = sa.Column(sa.Integer, primary_key=True) + during = sa.Column(DateTimeRangeType) + + self.Booking = Booking + + def create_booking(self, date_range): + booking = self.Booking( + during=date_range + ) + + self.session.add(booking) + self.session.commit() + return self.session.query(self.Booking).first() + + def test_nullify_range(self): + booking = self.create_booking(None) + assert booking.during is None + + @mark.parametrize( + ('date_range'), + ( + [datetime(2015, 1, 1), datetime(2015, 1, 3)], + [datetime(2015, 1, 1), inf], + [-inf, datetime(2015, 1, 1)] + ) + ) + def test_save_date_range(self, date_range): + booking = self.create_booking(date_range) + assert booking.during.lower == date_range[0] + assert booking.during.upper == date_range[1] + + def test_nullify_date_range(self): + booking = self.Booking( + during=intervals.DateInterval( + [datetime(2015, 1, 1), datetime(2015, 1, 3)] + ) + ) + + self.session.add(booking) + self.session.commit() + + booking = self.session.query(self.Booking).first() + booking.during = None + self.session.commit() + + booking = self.session.query(self.Booking).first() + assert booking.during is None + + def test_integer_coercion(self): + booking = self.Booking(during=datetime(2015, 1, 1)) + assert booking.during.lower == datetime(2015, 1, 1) + assert booking.during.upper == datetime(2015, 1, 1) + + +class TestDateRangeOnPostgres(DateRangeTestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + @mark.parametrize( + ('date_range', 'length'), + ( + ( + [datetime(2015, 1, 1), datetime(2015, 1, 3)], + timedelta(days=2) + ), + ( + [datetime(2015, 1, 1), datetime(2015, 1, 1)], + timedelta(days=0) + ), + ([-inf, datetime(2015, 1, 1)], None), + ([datetime(2015, 1, 1), inf], None), + ) + ) + def test_length(self, date_range, length): + self.create_booking(date_range) + query = ( + self.session.query(self.Booking.during.length) + ) + assert query.scalar() == length diff --git a/python-sqlalchemy-utils/tests/types/test_email.py b/python-sqlalchemy-utils/tests/types/test_email.py new file mode 100644 index 0000000..5b65f71 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_email.py @@ -0,0 +1,28 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import EmailType +from tests import TestCase + + +class TestEmailType(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + email = sa.Column(EmailType) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_saves_email_as_lowercased(self): + user = self.User( + email=u'Someone@example.com' + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.email == u'someone@example.com' diff --git a/python-sqlalchemy-utils/tests/types/test_encrypted.py b/python-sqlalchemy-utils/tests/types/test_encrypted.py new file mode 100644 index 0000000..a91e642 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_encrypted.py @@ -0,0 +1,244 @@ +from datetime import date, datetime, time + +import pytest +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils import ColorType, EncryptedType, PhoneNumberType +from sqlalchemy_utils.types.encrypted import AesEngine, FernetEngine +from tests import TestCase + +cryptography = None +try: + import cryptography # noqa +except ImportError: + pass + + +@mark.skipif('cryptography is None') +class EncryptedTypeTestCase(TestCase): + + @pytest.fixture(scope='function') + def user(self, request): + # set the values to the user object + self.user = self.User() + self.user.username = self.user_name + self.user.phone = self.user_phone + self.user.color = self.user_color + self.user.date = self.user_date + self.user.time = self.user_time + self.user.enum = self.user_enum + self.user.datetime = self.user_datetime + self.user.access_token = self.test_token + self.user.is_active = self.active + self.user.accounts_num = self.accounts_num + self.session.add(self.user) + self.session.commit() + + # register a finalizer to cleanup + def finalize(): + del self.user_name + del self.test_token + del self.active + del self.accounts_num + del self.test_key + del self.searched_user + + request.addfinalizer(finalize) + + return self.session.query(self.User).get(self.user.id) + + def generate_test_token(self): + import string + import random + token = '' + characters = string.ascii_letters + string.digits + for i in range(60): + token += ''.join(random.choice(characters)) + return token + + def create_models(self): + # set some test values + self.test_key = 'secretkey1234' + self.user_name = u'someone' + self.user_phone = u'(555) 555-5555' + self.user_color = u'#fff' + self.user_enum = 'One' + self.user_date = date(2010, 10, 2) + self.user_time = time(10, 12) + self.user_datetime = datetime(2010, 10, 2, 10, 12) + self.test_token = self.generate_test_token() + self.active = True + self.accounts_num = 2 + self.searched_user = None + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + username = sa.Column(EncryptedType( + sa.Unicode, + self.test_key, + self.__class__.encryption_engine) + ) + + access_token = sa.Column(EncryptedType( + sa.String, + self.test_key, + self.__class__.encryption_engine) + ) + + is_active = sa.Column(EncryptedType( + sa.Boolean, + self.test_key, + self.__class__.encryption_engine) + ) + + accounts_num = sa.Column(EncryptedType( + sa.Integer, + self.test_key, + self.__class__.encryption_engine) + ) + + phone = sa.Column(EncryptedType( + PhoneNumberType, + self.test_key, + self.__class__.encryption_engine) + ) + + color = sa.Column(EncryptedType( + ColorType, + self.test_key, + self.__class__.encryption_engine) + ) + + date = sa.Column(EncryptedType( + sa.Date, + self.test_key, + self.__class__.encryption_engine) + ) + + time = sa.Column(EncryptedType( + sa.Time, + self.test_key, + self.__class__.encryption_engine) + ) + + datetime = sa.Column(EncryptedType( + sa.DateTime, + self.test_key, + self.__class__.encryption_engine) + ) + + enum = sa.Column(EncryptedType( + sa.Enum('One', name='user_enum_t'), + self.test_key, + self.__class__.encryption_engine) + ) + + self.User = User + + class Team(self.Base): + __tablename__ = 'team' + id = sa.Column(sa.Integer, primary_key=True) + key = sa.Column(sa.String(50)) + name = sa.Column(EncryptedType( + sa.Unicode, + lambda: self._team_key, + self.__class__.encryption_engine) + ) + + self.Team = Team + + def test_unicode(self, user): + assert user.username == self.user_name + + def test_string(self, user): + assert user.access_token == self.test_token + + def test_boolean(self, user): + assert user.is_active == self.active + + def test_integer(self, user): + assert user.accounts_num == self.accounts_num + + def test_phone_number(self, user): + assert str(user.phone) == self.user_phone + + def test_color(self, user): + assert user.color.hex == self.user_color + + def test_date(self, user): + assert user.date == self.user_date + + def test_datetime(self, user): + assert user.datetime == self.user_datetime + + def test_time(self, user): + assert user.time == self.user_time + + def test_enum(self, user): + assert user.enum == self.user_enum + + def test_lookup_key(self): + # Add teams + self._team_key = 'one' + team = self.Team(key=self._team_key, name=u'One') + self.session.add(team) + self.session.commit() + team_1_id = team.id + + self._team_key = 'two' + team = self.Team(key=self._team_key) + team.name = u'Two' + self.session.add(team) + self.session.commit() + team_2_id = team.id + + # Lookup teams + self._team_key = self.session.query(self.Team.key).filter_by( + id=team_1_id + ).one()[0] + + team = self.session.query(self.Team).get(team_1_id) + + assert team.name == u'One' + + with pytest.raises(Exception): + self.session.query(self.Team).get(team_2_id) + + self.session.expunge_all() + + self._team_key = self.session.query(self.Team.key).filter_by( + id=team_2_id + ).one()[0] + + team = self.session.query(self.Team).get(team_2_id) + + assert team.name == u'Two' + + with pytest.raises(Exception): + self.session.query(self.Team).get(team_1_id) + + self.session.expunge_all() + + # Remove teams + self.session.query(self.Team).delete() + self.session.commit() + + +class TestAesEncryptedTypeTestcase(EncryptedTypeTestCase): + + encryption_engine = AesEngine + + def test_lookup_by_encrypted_string(self, user): + test = self.session.query(self.User).filter( + self.User.username == self.user_name + ).first() + + assert test.username == user.username + + +class TestFernetEncryptedTypeTestCase(EncryptedTypeTestCase): + + encryption_engine = FernetEngine diff --git a/python-sqlalchemy-utils/tests/types/test_int_range.py b/python-sqlalchemy-utils/tests/types/test_int_range.py new file mode 100644 index 0000000..c349496 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_int_range.py @@ -0,0 +1,314 @@ +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils import IntRangeType +from tests import TestCase + +intervals = None +inf = -1 +try: + import intervals + from infinity import inf +except ImportError: + pass + + +@mark.skipif('intervals is None') +class NumberRangeTestCase(TestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + persons_at_night = sa.Column(IntRangeType) + + def __repr__(self): + return 'Building(%r)' % self.id + + self.Building = Building + + def create_building(self, number_range): + building = self.Building( + persons_at_night=number_range + ) + + self.session.add(building) + self.session.commit() + return self.session.query(self.Building).first() + + def test_nullify_range(self): + building = self.create_building(None) + assert building.persons_at_night is None + + def test_update_with_none(self): + interval = intervals.IntInterval('(,)') + building = self.create_building(interval) + building.persons_at_night = None + assert building.persons_at_night is None + self.session.commit() + assert building.persons_at_night is None + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + ) + ) + def test_save_number_range(self, number_range): + building = self.create_building(number_range) + assert building.persons_at_night.lower == 1 + assert building.persons_at_night.upper == 3 + + def test_infinite_upper_bound(self): + building = self.create_building([1, inf]) + assert building.persons_at_night.lower == 1 + assert building.persons_at_night.upper == inf + + def test_infinite_lower_bound(self): + building = self.create_building([-inf, 1]) + assert building.persons_at_night.lower == -inf + assert building.persons_at_night.upper == 1 + + def test_nullify_number_range(self): + building = self.Building( + persons_at_night=intervals.IntInterval([1, 3]) + ) + + self.session.add(building) + self.session.commit() + + building = self.session.query(self.Building).first() + building.persons_at_night = None + self.session.commit() + + building = self.session.query(self.Building).first() + assert building.persons_at_night is None + + def test_string_coercion(self): + building = self.Building(persons_at_night='[12, 18]') + assert isinstance(building.persons_at_night, intervals.IntInterval) + + def test_integer_coercion(self): + building = self.Building(persons_at_night=15) + assert building.persons_at_night.lower == 15 + assert building.persons_at_night.upper == 15 + + +class TestIntRangeTypeOnPostgres(NumberRangeTestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + (0, 4) + ) + ) + def test_eq_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night == number_range) + ) + assert query.count() + + @mark.parametrize( + ('number_range', 'length'), + ( + ([1, 3], 2), + ([1, 1], 0), + ([-1, 1], 2), + ([-inf, 1], None), + ([0, inf], None), + ([0, 0], 0), + ([-3, -1], 2) + ) + ) + def test_length(self, number_range, length): + self.create_building(number_range) + query = ( + self.session.query(self.Building.persons_at_night.length) + ) + assert query.scalar() == length + + @mark.parametrize( + 'number_range', + ( + [[1, 3]], + ['1 - 3'], + [(0, 4)], + ) + ) + def test_in_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night.in_(number_range)) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + (0, 4), + ) + ) + def test_rshift_operator(self, number_range): + self.create_building([5, 6]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night >> number_range) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + (0, 4), + ) + ) + def test_lshift_operator(self, number_range): + self.create_building([-1, 0]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night << number_range) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + (1, 3), + 2 + ) + ) + def test_contains_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night.contains(number_range)) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + (0, 8), + (-inf, inf) + ) + ) + def test_contained_by_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night.contained_by(number_range)) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [2, 5], + '0 - 2', + 0 + ) + ) + def test_not_in_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(~ self.Building.persons_at_night.in_([number_range])) + ) + assert query.count() + + def test_eq_with_query_arg(self): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter( + self.Building.persons_at_night == + self.session.query( + self.Building.persons_at_night) + ).order_by(self.Building.persons_at_night).limit(1) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [1, 2], + '1 - 3', + (0, 4), + [0, 3], + 0, + 1, + ) + ) + def test_ge_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night >= number_range) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [0, 2], + 0, + [-inf, 2] + ) + ) + def test_gt_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night > number_range) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [1, 4], + 4, + [2, inf] + ) + ) + def test_le_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night <= number_range) + ) + assert query.count() + + @mark.parametrize( + 'number_range', + ( + [2, 4], + 4, + [1, inf] + ) + ) + def test_lt_operator(self, number_range): + self.create_building([1, 3]) + query = ( + self.session.query(self.Building) + .filter(self.Building.persons_at_night < number_range) + ) + assert query.count() + + +class TestNumberRangeTypeOnSqlite(NumberRangeTestCase): + pass diff --git a/python-sqlalchemy-utils/tests/types/test_ip_address.py b/python-sqlalchemy-utils/tests/types/test_ip_address.py new file mode 100644 index 0000000..3b5a4e5 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_ip_address.py @@ -0,0 +1,31 @@ +import six +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils.types import ip_address +from tests import TestCase + + +@mark.skipif('ip_address.ip_address is None') +class TestIPAddressType(TestCase): + def create_models(self): + class Visitor(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + ip_address = sa.Column(ip_address.IPAddressType) + + def __repr__(self): + return 'Visitor(%r)' % self.id + + self.Visitor = Visitor + + def test_parameter_processing(self): + visitor = self.Visitor( + ip_address=u'111.111.111.111' + ) + + self.session.add(visitor) + self.session.commit() + + visitor = self.session.query(self.Visitor).first() + assert six.text_type(visitor.ip_address) == u'111.111.111.111' diff --git a/python-sqlalchemy-utils/tests/types/test_json.py b/python-sqlalchemy-utils/tests/types/test_json.py new file mode 100644 index 0000000..8305a36 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_json.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils.types import json +from tests import TestCase + + +class JSONTestCase(TestCase): + def create_models(self): + class Document(self.Base): + __tablename__ = 'document' + id = sa.Column(sa.Integer, primary_key=True) + json = sa.Column(json.JSONType) + + self.Document = Document + + def test_list(self): + document = self.Document( + json=[1, 2, 3] + ) + + self.session.add(document) + self.session.commit() + + document = self.session.query(self.Document).first() + assert document.json == [1, 2, 3] + + def test_parameter_processing(self): + document = self.Document( + json={'something': 12} + ) + + self.session.add(document) + self.session.commit() + + document = self.session.query(self.Document).first() + assert document.json == {'something': 12} + + def test_non_ascii_chars(self): + document = self.Document( + json={'something': u'äääööö'} + ) + + self.session.add(document) + self.session.commit() + + document = self.session.query(self.Document).first() + assert document.json == {'something': u'äääööö'} + + +@mark.skipif('json.json is None') +class TestSqliteJSONType(JSONTestCase): + pass + + +@mark.skipif('json.json is None') +class TestPostgresJSONType(JSONTestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' diff --git a/python-sqlalchemy-utils/tests/types/test_locale.py b/python-sqlalchemy-utils/tests/types/test_locale.py new file mode 100644 index 0000000..573aa3d --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_locale.py @@ -0,0 +1,38 @@ +import sqlalchemy as sa +from pytest import mark, raises + +from sqlalchemy_utils.types import locale +from tests import TestCase + + +@mark.skipif('locale.babel is None') +class TestLocaleType(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + locale = sa.Column(locale.LocaleType) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_parameter_processing(self): + user = self.User( + locale=locale.babel.Locale(u'fi') + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(locale='en_US') + + assert isinstance(user.locale, locale.babel.Locale) + + def test_unknown_locale_throws_exception(self): + with raises(locale.babel.UnknownLocaleError): + self.User(locale=u'unknown') diff --git a/python-sqlalchemy-utils/tests/types/test_numeric_range.py b/python-sqlalchemy-utils/tests/types/test_numeric_range.py new file mode 100644 index 0000000..ad40acb --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_numeric_range.py @@ -0,0 +1,142 @@ +from decimal import Decimal + +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils import NumericRangeType +from tests import TestCase + +intervals = None +inf = 0 +try: + import intervals + from infinity import inf +except ImportError: + pass + + +@mark.skipif('intervals is None') +class NumericRangeTestCase(TestCase): + def create_models(self): + class Car(self.Base): + __tablename__ = 'car' + id = sa.Column(sa.Integer, primary_key=True) + price_range = sa.Column(NumericRangeType) + + self.Car = Car + + def create_car(self, number_range): + car = self.Car( + price_range=number_range + ) + + self.session.add(car) + self.session.commit() + return self.session.query(self.Car).first() + + def test_nullify_range(self): + car = self.create_car(None) + assert car.price_range is None + + @mark.parametrize( + 'number_range', + ( + [1, 3], + '1 - 3', + ) + ) + def test_save_number_range(self, number_range): + car = self.create_car(number_range) + assert car.price_range.lower == 1 + assert car.price_range.upper == 3 + + def test_infinite_upper_bound(self): + car = self.create_car([1, inf]) + assert car.price_range.lower == 1 + assert car.price_range.upper == inf + + def test_infinite_lower_bound(self): + car = self.create_car([-inf, 1]) + assert car.price_range.lower == -inf + assert car.price_range.upper == 1 + + def test_nullify_number_range(self): + car = self.Car( + price_range=intervals.DecimalInterval([1, 3]) + ) + + self.session.add(car) + self.session.commit() + + car = self.session.query(self.Car).first() + car.price_range = None + self.session.commit() + + car = self.session.query(self.Car).first() + assert car.price_range is None + + def test_string_coercion(self): + car = self.Car(price_range='[12, 18]') + assert isinstance(car.price_range, intervals.DecimalInterval) + + def test_integer_coercion(self): + car = self.Car(price_range=15) + assert car.price_range.lower == 15 + assert car.price_range.upper == 15 + + +class TestNumericRangeOnPostgres(NumericRangeTestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + @mark.parametrize( + ('number_range', 'length'), + ( + ([1, 3], 2), + ([1, 1], 0), + ([-1, 1], 2), + ([-inf, 1], None), + ([0, inf], None), + ([0, 0], 0), + ([-3, -1], 2) + ) + ) + def test_length(self, number_range, length): + self.create_car(number_range) + query = ( + self.session.query(self.Car.price_range.length) + ) + assert query.scalar() == length + + +@mark.skipif('intervals is None') +class TestNumericRangeWithStep(TestCase): + def create_models(self): + class Car(self.Base): + __tablename__ = 'car' + id = sa.Column(sa.Integer, primary_key=True) + price_range = sa.Column(NumericRangeType(step=Decimal('0.5'))) + + self.Car = Car + + def create_car(self, number_range): + car = self.Car( + price_range=number_range + ) + + self.session.add(car) + self.session.commit() + return self.session.query(self.Car).first() + + def test_passes_step_argument_to_interval_object(self): + car = self.create_car([Decimal('0.2'), Decimal('0.8')]) + assert car.price_range.lower == Decimal('0') + assert car.price_range.upper == Decimal('1') + assert car.price_range.step == Decimal('0.5') + + def test_passes_step_fetched_objects(self): + self.create_car([Decimal('0.2'), Decimal('0.8')]) + self.session.expunge_all() + car = self.session.query(self.Car).first() + assert car.price_range.lower == Decimal('0') + assert car.price_range.upper == Decimal('1') + assert car.price_range.step == Decimal('0.5') diff --git a/python-sqlalchemy-utils/tests/types/test_password.py b/python-sqlalchemy-utils/tests/types/test_password.py new file mode 100644 index 0000000..474ad94 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_password.py @@ -0,0 +1,174 @@ +import sqlalchemy as sa +from pytest import mark +from sqlalchemy import inspect + +from sqlalchemy_utils import Password, PasswordType, types # noqa +from tests import TestCase + + +@mark.skipif('types.password.passlib is None') +class TestPasswordType(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + password = sa.Column(PasswordType( + schemes=[ + 'pbkdf2_sha512', + 'pbkdf2_sha256', + 'md5_crypt', + 'hex_md5' + ], + + deprecated=['md5_crypt', 'hex_md5'] + )) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_encrypt(self): + """Should encrypt the password on setting the attribute.""" + obj = self.User() + obj.password = b'b' + + assert obj.password.hash != 'b' + assert obj.password.hash.startswith(b'$pbkdf2-sha512$') + + def test_check(self): + """ + Should be able to compare the plaintext against the + encrypted form. + """ + obj = self.User() + obj.password = 'b' + + assert obj.password == 'b' + assert obj.password != 'a' + + self.session.add(obj) + self.session.commit() + + obj = self.session.query(self.User).get(obj.id) + + assert obj.password == b'b' + assert obj.password != 'a' + + def test_check_and_update(self): + """ + Should be able to compare the plaintext against a deprecated + encrypted form and have it auto-update to the preferred version. + """ + + from passlib.hash import md5_crypt + + obj = self.User() + obj.password = Password(md5_crypt.encrypt('b')) + + assert obj.password.hash.decode('utf8').startswith('$1$') + assert obj.password == 'b' + assert obj.password.hash.decode('utf8').startswith('$pbkdf2-sha512$') + + def test_auto_column_length(self): + """Should derive the correct column length from the specified schemes. + """ + + from passlib.hash import pbkdf2_sha512 + + kind = inspect(self.User).c.password.type + + # name + rounds + salt + hash + ($ * 4) of largest hash + expected_length = len(pbkdf2_sha512.name) + expected_length += len(str(pbkdf2_sha512.max_rounds)) + expected_length += pbkdf2_sha512.max_salt_size + expected_length += pbkdf2_sha512.encoded_checksum_size + expected_length += 4 + + assert kind.length == expected_length + + def test_without_schemes(self): + assert PasswordType(schemes=[]).length == 1024 + + def test_compare(self): + from passlib.hash import md5_crypt + + obj = self.User() + obj.password = Password(md5_crypt.encrypt('b')) + + other = self.User() + other.password = Password(md5_crypt.encrypt('b')) + + # Not sure what to assert here; the test raised an error before. + assert obj.password != other.password + + def test_set_none(self): + + obj = self.User() + obj.password = None + + assert obj.password is None + + self.session.add(obj) + self.session.commit() + + obj = self.session.query(self.User).get(obj.id) + + assert obj.password is None + + def test_update_none(self): + """ + Should be able to change a password from ``None`` to a valid + password. + """ + + obj = self.User() + obj.password = None + + self.session.add(obj) + self.session.commit() + + obj = self.session.query(self.User).get(obj.id) + obj.password = 'b' + + self.session.commit() + + def test_compare_none(self): + """ + Should be able to compare a password of ``None``. + """ + + obj = self.User() + obj.password = None + + assert obj.password is None + assert obj.password == None # noqa + + obj.password = 'b' + + assert obj.password is not None + assert obj.password != None # noqa + + def test_check_and_update_persist(self): + """ + When a password is compared, the hash should update if needed to + change the algorithm; and, commit to the database. + """ + + from passlib.hash import md5_crypt + + obj = self.User() + obj.password = Password(md5_crypt.encrypt('b')) + + self.session.add(obj) + self.session.commit() + + assert obj.password.hash.decode('utf8').startswith('$1$') + assert obj.password == 'b' + + self.session.commit() + + obj = self.session.query(self.User).get(obj.id) + + assert obj.password.hash.decode('utf8').startswith('$pbkdf2-sha512$') + assert obj.password == 'b' diff --git a/python-sqlalchemy-utils/tests/types/test_phonenumber.py b/python-sqlalchemy-utils/tests/types/test_phonenumber.py new file mode 100644 index 0000000..dd8e3dc --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_phonenumber.py @@ -0,0 +1,125 @@ +import six +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils import PhoneNumber, PhoneNumberType, types # noqa +from tests import TestCase + + +@mark.skipif('types.phone_number.phonenumbers is None') +class TestPhoneNumber(object): + def setup_method(self, method): + self.valid_phone_numbers = [ + '040 1234567', + '+358 401234567', + '09 2501234', + '+358 92501234', + '0800 939393', + '09 4243 0456', + '0600 900 500' + ] + self.invalid_phone_numbers = [ + 'abc', + '+040 1234567', + '0111234567', + '358' + ] + + def test_valid_phone_numbers(self): + for raw_number in self.valid_phone_numbers: + number = PhoneNumber(raw_number, 'FI') + assert number.is_valid_number() + + def test_invalid_phone_numbers(self): + for raw_number in self.invalid_phone_numbers: + try: + number = PhoneNumber(raw_number, 'FI') + assert not number.is_valid_number() + except: + pass + + def test_phone_number_attributes(self): + number = PhoneNumber('+358401234567') + assert number.e164 == u'+358401234567' + assert number.international == u'+358 40 1234567' + assert number.national == u'040 1234567' + + def test_phone_number_str_repr(self): + number = PhoneNumber('+358401234567') + if six.PY2: + assert unicode(number) == number.national # noqa + assert str(number) == number.national.encode('utf-8') + else: + assert str(number) == number.national + + +@mark.skipif('types.phone_number.phonenumbers is None') +class TestPhoneNumberType(TestCase): + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + phone_number = sa.Column(PhoneNumberType()) + + self.User = User + + def setup_method(self, method): + super(TestPhoneNumberType, self).setup_method(method) + self.phone_number = PhoneNumber( + '040 1234567', + 'FI' + ) + self.user = self.User() + self.user.name = u'Someone' + self.user.phone_number = self.phone_number + self.session.add(self.user) + self.session.commit() + + def test_query_returns_phone_number_object(self): + queried_user = self.session.query(self.User).first() + assert queried_user.phone_number == self.phone_number + + def test_phone_number_is_stored_as_string(self): + result = self.session.execute( + 'SELECT phone_number FROM "user" WHERE id=:param', + {'param': self.user.id} + ) + assert result.first()[0] == u'+358401234567' + + def test_phone_number_with_extension(self): + user = self.User(phone_number='555-555-5555 Ext. 555') + + self.session.add(user) + self.session.commit() + self.session.refresh(user) + assert user.phone_number.extension == '555' + + def test_empty_phone_number_is_equiv_to_none(self): + user = self.User(phone_number='') + + self.session.add(user) + self.session.commit() + self.session.refresh(user) + assert user.phone_number is None + + def test_phone_number_is_none(self): + phone_number = None + user = self.User() + user.name = u'Someone' + user.phone_number = phone_number + self.session.add(user) + self.session.commit() + queried_user = self.session.query(self.User)[1] + assert queried_user.phone_number is None + result = self.session.execute( + 'SELECT phone_number FROM "user" WHERE id=:param', + {'param': user.id} + ) + assert result.first()[0] is None + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(phone_number='050111222') + + assert isinstance(user.phone_number, PhoneNumber) diff --git a/python-sqlalchemy-utils/tests/types/test_scalar_list.py b/python-sqlalchemy-utils/tests/types/test_scalar_list.py new file mode 100644 index 0000000..3f556fc --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_scalar_list.py @@ -0,0 +1,79 @@ +import six +import sqlalchemy as sa +from pytest import raises + +from sqlalchemy_utils import ScalarListType +from tests import TestCase + + +class TestScalarIntegerList(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + some_list = sa.Column(ScalarListType(int)) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_save_integer_list(self): + user = self.User( + some_list=[1, 2, 3, 4] + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.some_list == [1, 2, 3, 4] + + +class TestScalarUnicodeList(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + some_list = sa.Column(ScalarListType(six.text_type)) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_throws_exception_if_using_separator_in_list_values(self): + user = self.User( + some_list=[u','] + ) + + self.session.add(user) + with raises(sa.exc.StatementError) as db_err: + self.session.commit() + assert ( + "List values can't contain string ',' (its being used as " + "separator. If you wish for scalar list values to contain " + "these strings, use a different separator string.)" + ) in str(db_err.value) + + def test_save_unicode_list(self): + user = self.User( + some_list=[u'1', u'2', u'3', u'4'] + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.some_list == [u'1', u'2', u'3', u'4'] + + def test_save_and_retrieve_empty_list(self): + user = self.User( + some_list=[] + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert user.some_list == [] diff --git a/python-sqlalchemy-utils/tests/types/test_timezone.py b/python-sqlalchemy-utils/tests/types/test_timezone.py new file mode 100644 index 0000000..3d8a11c --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_timezone.py @@ -0,0 +1,41 @@ +import sqlalchemy as sa + +from sqlalchemy_utils.types import timezone +from tests import TestCase + + +class TestTimezoneType(TestCase): + def create_models(self): + class Visitor(self.Base): + __tablename__ = 'visitor' + id = sa.Column(sa.Integer, primary_key=True) + timezone_dateutil = sa.Column( + timezone.TimezoneType(backend='dateutil') + ) + timezone_pytz = sa.Column( + timezone.TimezoneType(backend='pytz') + ) + + def __repr__(self): + return 'Visitor(%r)' % self.id + + self.Visitor = Visitor + + def test_parameter_processing(self): + visitor = self.Visitor( + timezone_dateutil=u'America/Los_Angeles', + timezone_pytz=u'America/Los_Angeles' + ) + + self.session.add(visitor) + self.session.commit() + + visitor_dateutil = self.session.query(self.Visitor).filter_by( + timezone_dateutil=u'America/Los_Angeles' + ).first() + visitor_pytz = self.session.query(self.Visitor).filter_by( + timezone_pytz=u'America/Los_Angeles' + ).first() + + assert visitor_dateutil is not None + assert visitor_pytz is not None diff --git a/python-sqlalchemy-utils/tests/types/test_tsvector.py b/python-sqlalchemy-utils/tests/types/test_tsvector.py new file mode 100644 index 0000000..5a75a91 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_tsvector.py @@ -0,0 +1,71 @@ +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import TSVECTOR + +from sqlalchemy_utils import TSVectorType +from tests import TestCase + + +class TestTSVector(TestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + search_index = sa.Column( + TSVectorType(name, regconfig='pg_catalog.finnish') + ) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_generates_table(self): + assert 'search_index' in self.User.__table__.c + + def test_type_reflection(self): + reflected_metadata = sa.schema.MetaData() + table = sa.schema.Table( + 'user', + reflected_metadata, + autoload=True, + autoload_with=self.engine + ) + assert isinstance(table.c['search_index'].type, TSVECTOR) + + def test_catalog_and_columns_as_args(self): + type_ = TSVectorType('name', 'age', regconfig='pg_catalog.simple') + assert type_.columns == ('name', 'age') + assert type_.options['regconfig'] == 'pg_catalog.simple' + + def test_match(self): + expr = self.User.search_index.match(u'something') + assert str(expr.compile(self.connection)) == ( + '''"user".search_index @@ to_tsquery('pg_catalog.finnish', ''' + '''%(search_index_1)s)''' + ) + + def test_concat(self): + assert str(self.User.search_index | self.User.search_index) == ( + '"user".search_index || "user".search_index' + ) + + def test_match_concatenation(self): + concat = self.User.search_index | self.User.search_index + bind = self.session.bind + assert str(concat.match('something').compile(bind)) == ( + '("user".search_index || "user".search_index) @@ ' + "to_tsquery('pg_catalog.finnish', %(param_1)s)" + ) + + def test_match_with_catalog(self): + expr = self.User.search_index.match( + u'something', + postgresql_regconfig='pg_catalog.simple' + ) + assert str(expr.compile(self.connection)) == ( + '''"user".search_index @@ to_tsquery('pg_catalog.simple', ''' + '''%(search_index_1)s)''' + ) diff --git a/python-sqlalchemy-utils/tests/types/test_url.py b/python-sqlalchemy-utils/tests/types/test_url.py new file mode 100644 index 0000000..63970d4 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_url.py @@ -0,0 +1,35 @@ +import sqlalchemy as sa +from pytest import mark + +from sqlalchemy_utils.types import url +from tests import TestCase + + +@mark.skipif('url.furl is None') +class TestURLType(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + website = sa.Column(url.URLType) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_color_parameter_processing(self): + user = self.User( + website=url.furl(u'www.example.com') + ) + + self.session.add(user) + self.session.commit() + + user = self.session.query(self.User).first() + assert isinstance(user.website, url.furl) + + def test_scalar_attributes_get_coerced_to_objects(self): + user = self.User(website=u'www.example.com') + + assert isinstance(user.website, url.furl) diff --git a/python-sqlalchemy-utils/tests/types/test_uuid.py b/python-sqlalchemy-utils/tests/types/test_uuid.py new file mode 100644 index 0000000..5bc4437 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_uuid.py @@ -0,0 +1,41 @@ +import uuid + +import sqlalchemy as sa + +from sqlalchemy_utils import UUIDType +from tests import TestCase + + +class TestUUIDType(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(UUIDType, default=uuid.uuid4, primary_key=True) + + def __repr__(self): + return 'User(%r)' % self.id + + self.User = User + + def test_commit(self): + obj = self.User() + obj.id = uuid.uuid4().hex + + self.session.add(obj) + self.session.commit() + + u = self.session.query(self.User).one() + + assert u.id == obj.id + + def test_coerce(self): + obj = self.User() + obj.id = identifier = uuid.uuid4().hex + + assert isinstance(obj.id, uuid.UUID) + assert obj.id.hex == identifier + + obj.id = identifier = uuid.uuid4().bytes + + assert isinstance(obj.id, uuid.UUID) + assert obj.id.bytes == identifier diff --git a/python-sqlalchemy-utils/tests/types/test_weekdays.py b/python-sqlalchemy-utils/tests/types/test_weekdays.py new file mode 100644 index 0000000..1bda6f2 --- /dev/null +++ b/python-sqlalchemy-utils/tests/types/test_weekdays.py @@ -0,0 +1,52 @@ +import pytest +import sqlalchemy as sa + +from sqlalchemy_utils import i18n +from sqlalchemy_utils.primitives import WeekDays +from sqlalchemy_utils.types import WeekDaysType +from tests import TestCase + + +@pytest.mark.skipif('i18n.babel is None') +class WeekDaysTypeTestCase(TestCase): + def setup_method(self, method): + TestCase.setup_method(self, method) + i18n.get_locale = lambda: i18n.babel.Locale('en') + + def create_models(self): + class Schedule(self.Base): + __tablename__ = 'schedule' + id = sa.Column(sa.Integer, primary_key=True) + working_days = sa.Column(WeekDaysType) + + def __repr__(self): + return 'Schedule(%r)' % self.id + + self.Schedule = Schedule + + def test_color_parameter_processing(self): + schedule = self.Schedule( + working_days='0001111' + ) + self.session.add(schedule) + self.session.commit() + + schedule = self.session.query(self.Schedule).first() + assert isinstance(schedule.working_days, WeekDays) + + def test_scalar_attributes_get_coerced_to_objects(self): + schedule = self.Schedule(working_days=u'1010101') + + assert isinstance(schedule.working_days, WeekDays) + + +class TestWeekDaysTypeOnSQLite(WeekDaysTypeTestCase): + dns = 'sqlite:///:memory:' + + +class TestWeekDaysTypeOnPostgres(WeekDaysTypeTestCase): + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + +class TestWeekDaysTypeOnMySQL(WeekDaysTypeTestCase): + dns = 'mysql+pymysql://travis@localhost/sqlalchemy_utils_test' diff --git a/python-sqlalchemy-utils/tox.ini b/python-sqlalchemy-utils/tox.ini new file mode 100644 index 0000000..de28324 --- /dev/null +++ b/python-sqlalchemy-utils/tox.ini @@ -0,0 +1,8 @@ +[tox] +envlist = py26, py27, py33, py34 + +[testenv] +commands = py.test {posargs} +deps = + SQLAlchemy==1.0.4 + .[test_all] diff --git a/tests/runtest.sh b/tests/runtest.sh new file mode 100644 index 0000000..63ae942 --- /dev/null +++ b/tests/runtest.sh @@ -0,0 +1,20 @@ +#!/bin/bash -x + +case $1 in +python-sqlalchemy-utils) + echo "Testing $1" + python -c "import utils" + EC=$? +;; +python3-sqlalchemy-utils) + echo "Testing $1" + python3 -c "import utils" + EC=$? +;; +*) + echo "Test not defined for $1!" + EC=1 +;; +esac + +exit $EC