diff --git a/neunton/common/exceptions.py b/neunton/common/exceptions.py index c99c254..e24f7bc 100644 --- a/neunton/common/exceptions.py +++ b/neunton/common/exceptions.py @@ -235,3 +235,7 @@ class InvalidSharedSetting(QuantumException): class InvalidExtenstionEnv(QuantumException): message = _("Invalid extension environment: %(reason)s") + +class DBError(Error): + message = _("Database error") + diff --git a/neunton/db/api.py b/neunton/db/api.py index 238a9f9..737c748 100644 --- a/neunton/db/api.py +++ b/neunton/db/api.py @@ -20,12 +20,16 @@ import logging import time +import time + import sqlalchemy as sql from sqlalchemy import create_engine from sqlalchemy.exc import DisconnectionError +from sqlalchemy.exc import OperationalError from sqlalchemy.orm import sessionmaker, exc from quantum.db import model_base +from quantum.common.exceptions import DBError LOG = logging.getLogger(__name__) @@ -33,28 +37,61 @@ LOG = logging.getLogger(__name__) _ENGINE = None _MAKER = None BASE = model_base.BASE +OPTIONS = None +def is_db_connection_error(args): + """Return True if error in connecting to db.""" + # NOTE(adam_g): This is currently MySQL specific and needs to be extended + # to support Postgres and others. + conn_err_codes = ('2002', '2003', '2006', '2013', '2014', '2045', '2055') + for err_code in conn_err_codes: + if args.find(err_code) != -1: + return True + return False -class MySQLPingListener(object): - """ - Ensures that MySQL connections checked out of the - pool are alive. +def wrap_db_error(f): + """Function wrapper to capture DB errors - Borrowed from: - http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f - """ + If an exception is thrown by the wrapped function, + determine if it represents a database connection error. + If so, retry the wrapped function, and repeat until it succeeds + or we reach a configurable maximum number of retries. + If it is not a connection error, or we exceeded the retry limit, + raise a DBError. - def checkout(self, dbapi_con, con_record, con_proxy): - try: - dbapi_con.cursor().execute('select 1') - except dbapi_con.OperationalError, ex: - if ex.args[0] in (2006, 2013, 2014, 2045, 2055): - LOG.warn('Got mysql server has gone away: %s', ex) - raise DisconnectionError("Database server went away") - else: + """ + global OPTIONS + def _wrap_db_error(*args, **kwargs): + next_interval = OPTIONS.get('reconnect_interval', 1) + remaining = OPTIONS.get('sql_max_retries', -1) + if remaining == -1: + remaining = 'infinite' + while True: + try: + return f(*args, **kwargs) + except OperationalError, e: + if is_db_connection_error(e.args[0]): + if remaining == 0: + logging.warn('DB exceeded retry limit.') + raise DBError(e) + if remaining != 'infinite': + remaining -= 1 + logging.warn('DB connection error, ' + 'retrying in %i seconds.' % next_interval) + time.sleep(next_interval) + if OPTIONS.get('inc_reconnect_interval', True): + next_interval = min(next_interval * 2, + OPTIONS.get('max_reconnect_interval', 60)) + else: + logging.warn('DB exception wrapped.') + raise DBError(e) + except Exception, e: raise + _wrap_db_error.func_name = f.func_name + return _wrap_db_error + def configure_db(options): """ @@ -63,6 +100,8 @@ def configure_db(options): :param options: Mapping of configuration options """ + global OPTIONS + OPTIONS = options global _ENGINE if not _ENGINE: connection_dict = sql.engine.url.make_url(options['sql_connection']) @@ -72,9 +111,6 @@ def configure_db(options): 'convert_unicode': True, } - if 'mysql' in connection_dict.drivername: - engine_args['listeners'] = [MySQLPingListener()] - _ENGINE = create_engine(options['sql_connection'], **engine_args) base = options.get('base', BASE) if not register_models(base): @@ -101,10 +137,18 @@ def get_session(autocommit=True, expire_on_commit=False): global _MAKER, _ENGINE if not _MAKER: assert _ENGINE + class OurQuery(sql.orm.query.Query): + pass + query = OurQuery + query.all = wrap_db_error(query.all) + query.first = wrap_db_error(query.first) _MAKER = sessionmaker(bind=_ENGINE, autocommit=autocommit, - expire_on_commit=expire_on_commit) - return _MAKER() + expire_on_commit=expire_on_commit, + query_cls=OurQuery) + session = _MAKER() + session.flush = wrap_db_error(session.flush) + return session def retry_registration(remaining, reconnect_interval, base=BASE):