0001import atexit
0002import inspect
0003import sys
0004import os
0005import threading
0006import types
0007try:
0008    from urlparse import urlparse, parse_qsl
0009    from urllib import unquote, quote, urlencode
0010except ImportError:
0011    from urllib.parse import urlparse, parse_qsl, unquote, quote, urlencode
0012
0013import warnings
0014import weakref
0015
0016from .cache import CacheSet
0017from . import classregistry
0018from . import col
0019from .converters import sqlrepr
0020from . import sqlbuilder
0021from .util.threadinglocal import local as threading_local
0022from .compat import PY2, string_type, unicode_type
0023
# Silence the warning some DB-API drivers emit whenever the optional
# ``cursor.lastrowid`` extension is accessed.  This installs a filter in the
# process-wide ``warnings`` filter list.
warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")

# Module-level registry; presumably populated by registerConnectionInstance()
# (not visible here) -- TODO confirm what its keys and values are.
_connections = {}
0027
0028
0029def _closeConnection(ref):
0030    conn = ref()
0031    if conn is not None:
0032        conn.close()
0033
0034
class ConsoleWriter:
    """Debug writer that prints each message on its own line to
    ``sys.stdout`` or ``sys.stderr``.

    The stream is looked up on ``sys`` at every write, so a replaced
    ``sys.stdout``/``sys.stderr`` is honored.
    """

    def __init__(self, connection, loglevel):
        # An empty or None loglevel means stdout; 'stderr' selects stderr.
        if not loglevel:
            loglevel = "stdout"
        self.loglevel = loglevel
        # Encoding used for unicode messages on Python 2; falls back to
        # ASCII when the connection does not provide a (non-empty) one.
        encoding = getattr(connection, "dbEncoding", None)
        self.dbEncoding = encoding if encoding else "ascii"

    def write(self, text):
        stream = getattr(sys, self.loglevel)
        needs_encoding = PY2 and isinstance(text, unicode_type)
        if needs_encoding:
            try:
                text = text.encode(self.dbEncoding)
            except UnicodeEncodeError:
                # Fall back to the repr, minus its leading u' and closing '.
                text = repr(text)[2:-1]
        stream.write(text + '\n')
0049
0050
class LogWriter:
    """Debug writer that forwards each message to a logger method.

    ``connection`` is accepted so the constructor matches ConsoleWriter's
    (makeDebugWriter passes it to both) but it is not used here.
    """

    def __init__(self, connection, logger, loglevel):
        self.logger = logger
        self.loglevel = loglevel
        # Resolve the logger method named by loglevel (e.g. 'debug') once.
        self.logmethod = getattr(self.logger, self.loglevel)

    def write(self, text):
        emit = self.logmethod
        emit(text)
0059
0060
def makeDebugWriter(connection, loggerName, loglevel):
    """Build the object that receives debug output for *connection*.

    With a (truthy) *loggerName*, messages go to ``logging.getLogger(
    loggerName)`` via the method named *loglevel*; otherwise they are
    printed to the console stream selected by *loglevel*.
    """
    if loggerName:
        import logging
        return LogWriter(connection, logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(connection, loglevel)
0067
0068
class Boolean(object):
    """Truth-value converter that also understands keyword strings.

    ``Boolean(value)`` returns a plain ``bool`` (never a Boolean instance):
    yes/no, true/false, on/off and 1/0 map to the obvious value, compared
    case-insensitively; any other value falls back to ``bool(value)``.
    """
    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
                 '0': False, 'no': False, 'false': False, 'off': False}

    def __new__(cls, value):
        try:
            keyword = value.lower()
            result = Boolean._keywords[keyword]
        except (KeyError, AttributeError):
            # Not string-like, or not one of the recognized keywords.
            result = bool(value)
        return result
0083
0084
class DBConnection:
    """Backend-independent base for SQLObject connections.

    Stores configuration shared by every backend (debug switches, cache,
    style, auto-commit policy, class registry), registers itself so it is
    closed at interpreter exit, and provides helpers to build and parse
    connection URIs.  Subclasses are expected to supply ``dbName``,
    ``host``, ``port``, ``db`` (and optionally ``user``/``password``), which
    :meth:`uri` and :meth:`oldUri` read, and a ``close()`` method, which the
    atexit hook registered in ``__init__`` calls.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        """Initialize shared connection state.

        ``debug``, ``debugOutput``, ``debugThreading``, ``cache`` and
        ``autoCommit`` accept booleans or the keyword strings understood by
        :class:`Boolean` (e.g. ``'yes'``/``'off'``), as they may come from
        URI query arguments.  ``debug='Pool'`` is kept verbatim: it is the
        special mode in which pool ACQUIRE/RELEASE messages are also
        printed.  ``logger``/``loglevel`` choose where debug output goes
        (see :func:`makeDebugWriter`).
        """
        self.name = name
        # Boolean('Pool') would collapse to plain True, making the 'Pool'
        # debug mode (checked with ``self.debug != 'Pool'``) unreachable.
        if debug == 'Pool':
            self.debug = debug
        else:
            self.debug = Boolean(debug)
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        self.debugWriter = makeDebugWriter(self, logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        self.style = style
        # id(raw connection) -> sequential number used in debug output.
        self._connectionNumbers = {}
        self._connectionCount = 1
        self.autoCommit = Boolean(autoCommit)
        self.registry = registry or None
        classregistry.registry(self.registry).addCallback(self.soClassAdded)
        registerConnectionInstance(self)
        # A weak reference so the atexit hook does not keep us alive.
        atexit.register(_closeConnection, weakref.ref(self))

    def oldUri(self):
        """Return an old-style URI for this connection (nothing quoted)."""
        auth = getattr(self, 'user', '') or ''
        password = getattr(self, 'password', None)
        if auth:
            if password:
                auth = auth + ':' + password
            auth = auth + '@'
        else:
            assert not password, (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + db

    def uri(self):
        """Return a URI for this connection that :meth:`_parseURI` can
        parse back.

        User name and password are percent-quoted with no safe characters,
        so '/', ':' and '@' inside credentials cannot be mistaken for URI
        delimiters; the database path keeps '/' unquoted.
        """
        auth = getattr(self, 'user', '') or ''
        password = getattr(self, 'password', None)
        if auth:
            auth = quote(auth, safe='')
            if password:
                auth = auth + ':' + quote(password, safe='')
            auth = auth + '@'
        else:
            assert not password, (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + quote(db)

    @classmethod
    def connectionFromOldURI(cls, uri):
        """Build a connection from an old-style (unquoted) URI."""
        return cls._connectionFromParams(*cls._parseOldURI(uri))

    @classmethod
    def connectionFromURI(cls, uri):
        """Build a connection from a (percent-quoted) URI."""
        return cls._connectionFromParams(*cls._parseURI(uri))

    @staticmethod
    def _parseOldURI(uri):
        """Parse an old-style connection URI.

        Returns ``(user, password, host, port, path, args)``; missing parts
        are ``None`` and ``args`` maps query argument names to their
        unquoted values.  Empty query segments (e.g. a trailing '&') are
        ignored.  On Windows a path like ``C|/path`` becomes ``C:/path``.

        Raises ValueError for a non-integer or out-of-range port.
        """
        schema, rest = uri.split(':', 1)
        assert rest.startswith('/'), (
            "URIs must start with scheme:/ -- "
            "you did not include a / (in %r)" % rest)
        if rest.startswith('/') and not rest.startswith('//'):
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            host = None
            rest = rest[3:]
        else:
            rest = rest[2:]
            if rest.find('/') == -1:
                host = rest
                rest = ''
            else:
                host, rest = rest.split('/', 1)
        if host and host.find('@') != -1:
            user, host = host.rsplit('@', 1)
            if user.find(':') != -1:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and host.find(':') != -1:
            _host, port = host.split(':')
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, "
                                 "got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, "
                                 "got '%d' instead" % port)
            host = _host
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if path.find('?') != -1:
            path, arglist = path.split('?', 1)
            for single in arglist.split('&'):
                if not single:
                    # Tolerate '?&a=1' or a trailing '&'.
                    continue
                argname, argvalue = single.split('=', 1)
                args[argname] = unquote(argvalue)
        return user, password, host, port, path, args

    @staticmethod
    def _parseURI(uri):
        """Parse a percent-quoted connection URI.

        Returns ``(user, password, host, port, path, args)`` with user,
        password and path unquoted; the fragment is ignored.
        """
        parsed = urlparse(uri)
        if sys.version_info[0:2] == (2, 6):
            # In python 2.6, urlparse only parses the uri completely
            # for certain schemes, so we force the scheme to
            # something that will be parsed correctly
            scheme = parsed.scheme
            parsed = urlparse(uri.replace(scheme, 'http', 1))
        host, path = parsed.hostname, parsed.path
        user, password, port = None, None, None
        if parsed.username:
            user = unquote(parsed.username)
        if parsed.password:
            password = unquote(parsed.password)
        if parsed.port:
            port = int(parsed.port)

        path = unquote(path)
        if (os.name == 'nt') and (len(path) > 2):
            # Preserve backward compatibility with URIs like /C|/path;
            # replace '|' by ':'
            if path[2] == '|':
                path = "%s:%s" % (path[0:2], path[3:])
            # Remove leading slash
            if (path[0] == '/') and (path[2] == ':'):
                path = path[1:]

        query = parsed.query
        # hash-tag / fragment is ignored

        args = {}
        if query:
            for name, value in parse_qsl(query):
                args[name] = value

        return user, password, host, port, path, args

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()
0269
0270
class ConnWrapper(object):

    """
    This represents a SQLObject class that is bound to a specific
    connection (instances have a connection instance variable, but
    classes are global, so this binds the connection variable
    lazily when a class method is accessed).

    Calling the wrapper instantiates the class with ``connection=`` set;
    methods whose signature has a ``connection`` argument are returned
    wrapped in a ConnMethodWrapper that supplies it.
    """
    # @@: methods that take connection arguments should be explicitly
    # marked up instead of the implicit use of a connection argument
    # and signature introspection.

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        if attr in ('_soClass', '_connection'):
            # Only reachable when __init__ never ran (e.g. an instance made
            # via __new__ or during unpickling); without this guard the
            # lookup below would recurse forever.
            raise AttributeError(attr)
        meth = getattr(self._soClass, attr)
        if not isinstance(meth, types.MethodType):
            # We don't need to wrap non-methods
            return meth
        try:
            takes_conn = meth.takes_connection
        except AttributeError:
            # inspect.getargspec was removed in Python 3.11; prefer
            # getfullargspec and fall back for Python 2.
            getargspec = getattr(inspect, 'getfullargspec', None)
            if getargspec is None:
                getargspec = inspect.getargspec
            spec = getargspec(meth)
            args, varargs, varkw = spec[0], spec[1], spec[2]
            assert not varkw and not varargs, (
                "I cannot tell whether I must wrap this method, "
                "because it takes **kw: %r"
                % meth)
            takes_conn = 'connection' in args
            # Cache the answer on the underlying function for next time.
            meth.__func__.takes_connection = takes_conn
        if not takes_conn:
            return meth
        return ConnMethodWrapper(meth, self._connection)
0309
0310
class ConnMethodWrapper(object):
    """Callable proxy that injects a fixed ``connection`` keyword into every
    call of the wrapped method; any other attribute is read from the method.
    """

    def __init__(self, method, connection):
        self._method = method
        self._connection = connection

    def __getattr__(self, attr):
        # Only reached for names not found on the proxy itself.
        return getattr(self._method, attr)

    def __call__(self, *args, **kw):
        kw.update(connection=self._connection)
        target = self._method
        return target(*args, **kw)

    def __repr__(self):
        template = '<Wrapped %r with connection %r>'
        return template % (self._method, self._connection)
0327
0328
class DBAPI(DBConnection):

    """
    Base class for connections built on a DB-API 2.0 driver module.

    Subclass must define a `makeConnection()` method, which
    returns a newly-created connection object, and set ``module`` to the
    driver module (its ``Binary`` and ``Error`` names are used here).

    ``_queryInsertID`` must also be defined (``queryInsertID`` delegates to
    it), and ``supportTransactions`` is read by :meth:`releaseConnection`.
    """
    dbName = None

    def __init__(self, **kw):
        # Pool state is created before the base initializer runs; close()
        # also tolerates a missing ``_pool`` for half-built instances.
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        # Type of the driver's binary wrapper (consumers not visible here).
        self._binaryType = type(self.module.Binary(b''))

    def _runWithConnection(self, meth, *args):
        """Call ``meth(conn, *args)`` with a pooled connection, always
        releasing the connection afterwards, and return its result."""
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val

    def getConnection(self):
        """Take a connection from the pool, creating one when it is empty.

        New connections get a sequential number used in debug output.
        """
        with self._poolLock:
            if not self._pool:
                conn = self.makeConnection()
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            else:
                conn = self._pool.pop()
            if self.debug:
                s = 'ACQUIRE'
                if self._pool is not None:
                    s += ' pool=[%s]' % ', '.join(
                        [str(self._connectionNumbers[id(v)])
                            for v in self._pool])
                self.printDebug(conn, s, 'Pool')
            return conn

    def releaseConnection(self, conn, explicit=False):
        """Give *conn* back to the pool (or close it when pooling is off).

        For an implicit release (*explicit* false) on a backend that
        supports transactions, pending work is first finished according to
        ``autoCommit``: committed when true, rolled back when false, and
        rolled back followed by an ``Exception`` when it is
        ``'exception'``.  In that last case the connection is still
        returned to the pool before the exception is raised.
        """
        if self.debug:
            if explicit:
                s = 'RELEASE (explicit)'
            else:
                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                s += ' no pooling'
            else:
                s += ' pool=[%s]' % ', '.join(
                    [str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        if self.supportTransactions and not explicit:
            if self.autoCommit == 'exception':
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                # Previously the connection was dropped here and never made
                # it back to the pool.
                self._releaseToPool(conn)
                raise Exception('Object used outside of a transaction; '
                                'implicit COMMIT or ROLLBACK not allowed')
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                # No explicit COMMIT when the raw connection reports a true
                # ``autocommit`` attribute.
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        self._releaseToPool(conn)

    def _releaseToPool(self, conn):
        """Put *conn* back into the pool, or close it if pooling is off."""
        if self._pool is not None:
            if conn not in self._pool:
                # @@: We can get duplicate releasing of connections with
                # the __del__ in Iteration (unfortunately, not sure why
                # it happens)
                self._pool.insert(0, conn)
        else:
            conn.close()

    def printDebug(self, conn, s, name, type='query'):
        """Write one debug line about *conn* through ``self.debugWriter``.

        Messages tagged ``'Pool'`` are only written when ``self.debug`` is
        the string ``'Pool'``.  When *type* is not ``'query'`` the payload
        *s* is ``repr()``-ed (used for result rows).
        """
        if name == 'Pool' and self.debug != 'Pool':
            return
        if type == 'query':
            sep = ': '
        else:
            sep = '->'
            s = repr(s)
        n = self._connectionNumbers[id(conn)]
        # Pad the tag to 8 columns so messages line up.
        spaces = ' ' * (8 - len(name))
        if self.debugThreading:
            threadName = threading.current_thread().name
            threadName = ':' + threadName + ' ' * (8 - len(threadName))
        else:
            threadName = ''
        msg = '%2i%s/%s%s%s %s' % (n, threadName, name, spaces, sep, s)
        self.debugWriter.write(msg)

    def _executeRetry(self, conn, cursor, query):
        """Execute *query* on *cursor*; this base version does not retry
        (presumably subclasses override it with retry logic)."""
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)

    def _query(self, conn, s):
        """Execute statement *s* on *conn*, discarding any result."""
        if self.debug:
            self.printDebug(conn, s, 'Query')
        self._executeRetry(conn, conn.cursor(), s)

    def query(self, s):
        """Execute statement *s* on a pooled connection."""
        return self._runWithConnection(self._query, s)

    def _queryAll(self, conn, s):
        """Execute *s* on *conn* and return ``cursor.fetchall()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return value

    def queryAll(self, s):
        """Execute *s* on a pooled connection and return all rows."""
        return self._runWithConnection(self._queryAll, s)

    def _queryAllDescription(self, conn, s):
        """
        Like queryAll, but returns (description, rows), where the
        description is cursor.description (which gives row types)
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return c.description, value

    def queryAllDescription(self, s):
        """Pooled-connection version of :meth:`_queryAllDescription`."""
        return self._runWithConnection(self._queryAllDescription, s)

    def _queryOne(self, conn, s):
        """Execute *s* on *conn* and return ``cursor.fetchone()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchone()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryOne', 'result')
        return value

    def queryOne(self, s):
        """Execute *s* on a pooled connection and return the first row."""
        return self._runWithConnection(self._queryOne, s)

    def _insertSQL(self, table, names, values):
        """Build an INSERT statement with each value rendered by sqlrepr."""
        return ("INSERT INTO %s (%s) VALUES (%s)" %
                (table, ', '.join(names),
                 ', '.join([self.sqlrepr(v) for v in values])))

    def transaction(self):
        """Start and return a new :class:`Transaction`."""
        return Transaction(self)

    def queryInsertID(self, soInstance, id, names, values):
        """Insert a row via the subclass's ``_queryInsertID`` on a pooled
        connection and return its result."""
        return self._runWithConnection(self._queryInsertID, soInstance, id,
                                       names, values)

    def iterSelect(self, select):
        """Return the select's iteration object; it owns a pooled
        connection and releases it when done."""
        return select.IterationClass(self, self.getConnection(),
                                     select, keepConnection=False)

    def accumulateSelect(self, select, *expressions):
        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
            to the select object.

            Returns the single value for one expression, otherwise the row.
        """
        q = (select.queryForSelect().newItems(expressions)
             .unlimited().orderBy(None))
        q = self.sqlrepr(q)
        val = self.queryOne(q)
        if len(expressions) == 1:
            val = val[0]
        return val

    def queryForSelect(self, select):
        """Render the SQL text for *select*."""
        return self.sqlrepr(select.queryForSelect())

    def _SO_createJoinTable(self, join):
        self.query(self._SO_createJoinTableSQL(join))

    def _SO_createJoinTableSQL(self, join):
        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
                (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join)))

    def _SO_dropJoinTable(self, join):
        self.query("DROP TABLE %s" % join.intermediateTable)

    def _SO_createIndex(self, soClass, index):
        self.query(self.createIndexSQL(soClass, index))

    def createIndexSQL(self, soClass, index):
        # Raised explicitly (not ``assert 0``) so it also fails under -O.
        raise NotImplementedError('Implement in subclasses')

    def createTable(self, soClass):
        """Create the table for *soClass*; return the constraint and extra
        SQL statements still to be run."""
        createSql, constraints = self.createTableSQL(soClass)
        self.query(createSql)

        return constraints

    def createReferenceConstraints(self, soClass):
        """Return the non-empty reference-constraint SQL for every foreign
        key column of *soClass*."""
        refConstraints = [self.createReferenceConstraint(soClass, column)
                          for column in soClass.sqlmeta.columnList
                          if isinstance(column, col.SOForeignKey)]
        refConstraintDefs = [constraint for constraint in refConstraints
                             if constraint]
        return refConstraintDefs

    def createSQL(self, soClass):
        """Return ``soClass.sqlmeta.createSQL`` normalized to a list.

        It may be a string, list, tuple, or a dict keyed by backend
        ``dbName``.
        """
        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
        if tableCreateSQLs:
            # string_type (not str) so unicode SQL is accepted on Python 2.
            assert isinstance(tableCreateSQLs,
                              (string_type, list, dict, tuple)), (
                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
                (soClass.__name__))
            if isinstance(tableCreateSQLs, dict):
                tableCreateSQLs = tableCreateSQLs.get(
                    soClass._connection.dbName, [])
            if isinstance(tableCreateSQLs, string_type):
                tableCreateSQLs = [tableCreateSQLs]
            if isinstance(tableCreateSQLs, tuple):
                tableCreateSQLs = list(tableCreateSQLs)
            assert isinstance(tableCreateSQLs, list), (
                'Unable to create a list from %s.sqlmeta.createSQL' %
                (soClass.__name__))
        return tableCreateSQLs or []

    def createTableSQL(self, soClass):
        """Return ``(CREATE TABLE statement, constraints + extra SQL)``."""
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        createSql = ('CREATE TABLE %s (\n%s\n)' %
                     (soClass.sqlmeta.table, self.createColumns(soClass)))
        return createSql, constraints + extraSQL

    def createColumns(self, soClass):
        """Return the indented, comma-separated column definitions."""
        # ``column`` rather than ``col`` to avoid shadowing the col module.
        columnDefs = [self.createIDColumn(soClass)] + [
            self.createColumn(soClass, column)
            for column in soClass.sqlmeta.columnList]
        return ",\n".join(["    %s" % c for c in columnDefs])

    def createReferenceConstraint(self, soClass, col):
        raise NotImplementedError("Implement in subclasses")

    def createColumn(self, soClass, col):
        raise NotImplementedError("Implement in subclasses")

    def dropTable(self, tableName, cascade=False):
        """Drop *tableName*; ``cascade`` is ignored by this base version."""
        self.query("DROP TABLE %s" % tableName)

    def clearTable(self, tableName):
        # 3-03 @@: Should this have a WHERE 1 = 1 or similar
        # clause?  In some configurations without the WHERE clause
        # the query won't go through, but maybe we shouldn't override
        # that.
        self.query("DELETE FROM %s" % tableName)

    def createBinary(self, value):
        """
        Create a binary object wrapper for the given database.
        """
        # Default is Binary() function from the connection driver.
        return self.module.Binary(value)

    # The _SO_* series of methods are sorts of "friend" methods
    # with SQLObject.  They grab values from the SQLObject instances
    # or classes freely, but keep the SQLObject class from accessing
    # the database directly.  This way no SQL is actually created
    # in the SQLObject class.

    def _SO_update(self, so, values):
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                               for dbName, value in values]),
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectOne(self, so, columnNames):
        return self._SO_selectOneAlt(so, columnNames, so.q.id == so.id)

    def _SO_selectOneAlt(self, so, columnNames, condition):
        if columnNames:
            # Plain column names become SQL constants; expressions pass
            # through.  (A conditional expression, unlike ``and/or``, does
            # not depend on the truthiness of the SQLConstant.)
            columns = [sqlbuilder.SQLConstant(x)
                       if isinstance(x, string_type) else x
                       for x in columnNames]
        else:
            columns = None
        return self.queryOne(self.sqlrepr(sqlbuilder.Select(
            columns, staticTables=[so.sqlmeta.table], clause=condition)))

    def _SO_delete(self, so):
        self.query("DELETE FROM %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectJoin(self, soClass, column, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (soClass.sqlmeta.idName,
                              soClass.sqlmeta.table,
                              column,
                              self.sqlrepr(value)))

    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (getColumn,
                              table,
                              joinColumn,
                              self.sqlrepr(value)))

    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
                   (table,
                    firstColumn,
                    self.sqlrepr(firstValue),
                    secondColumn,
                    self.sqlrepr(secondValue)))

    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
                   (table,
                    firstColumn,
                    secondColumn,
                    self.sqlrepr(firstValue),
                    self.sqlrepr(secondValue)))

    def _SO_columnClause(self, soClass, kw):
        """Build a WHERE clause (``col = value`` joined by AND, ``IS`` for
        None) from keyword arguments naming columns of *soClass*.

        Consumes the recognized keys from *kw*; raises TypeError if any
        unrecognized keys remain.  Returns None when nothing matched.
        """
        from . import main
        ops = {None: "IS"}
        data = []
        if 'id' in kw:
            data.append((soClass.sqlmeta.idName, kw.pop('id')))
        for soColumn in soClass.sqlmeta.columnList:
            key = soColumn.name
            if key in kw:
                val = kw.pop(key)
                if soColumn.from_python:
                    val = soColumn.from_python(
                        val,
                        sqlbuilder.SQLObjectState(soClass, connection=self))
                data.append((soColumn.dbName, val))
            elif soColumn.foreignName in kw:
                obj = kw.pop(soColumn.foreignName)
                if isinstance(obj, main.SQLObject):
                    data.append((soColumn.dbName, obj.id))
                else:
                    data.append((soColumn.dbName, obj))
        if kw:
            # Report all leftover keys; list() keeps the message the same
            # on Python 2 and 3 (no ``dict_keys(...)`` wrapper).
            raise TypeError("got an unexpected keyword argument(s): "
                            "%r" % list(kw))

        if not data:
            return None
        return ' AND '.join(
            ['%s %s %s' %
             (dbName, ops.get(value, "="), self.sqlrepr(value))
             for dbName, value
             in data])

    def sqlrepr(self, v):
        """Render *v* as SQL for this backend."""
        return sqlrepr(v, self.dbName)

    def __del__(self):
        self.close()

    def close(self):
        """Close every pooled connection and empty the pool.

        Driver ``Error`` exceptions raised while closing are ignored
        (best-effort); checked-out connections are not touched.
        """
        if not hasattr(self, '_pool'):
            # Probably there was an exception while creating this
            # instance, so it is incomplete.
            return
        if not self._pool:
            return
        with self._poolLock:
            # Re-check under the lock: another thread may have emptied the
            # pool in the meantime.
            if not self._pool:
                return
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    pass
            del conn
            del conns

    def createEmptyDatabase(self):
        """
        Create an empty database.
        """
        raise NotImplementedError
0736
0737
class Iteration(object):
    """Iterator over the rows of a select, yielding objects one at a time.

    The query runs on construction; rows are fetched lazily with
    ``cursor.fetchone()``.  Once exhausted (or garbage-collected) the raw
    connection is handed back to *dbconn* unless *keepConnection* is true.
    """

    def __init__(self, dbconn, rawconn, select, keepConnection=False):
        self.dbconn = dbconn
        self.rawconn = rawconn
        self.select = select
        self.keepConnection = keepConnection
        self.cursor = rawconn.cursor()
        self.query = self.dbconn.queryForSelect(select)
        if dbconn.debug:
            dbconn.printDebug(rawconn, self.query, 'Select')
        self.dbconn._executeRetry(self.rawconn, self.cursor, self.query)

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return the next object (``None`` when the row's id is NULL).

        Raises StopIteration when the rows are exhausted, and keeps doing
        so on every later call.
        """
        if self.cursor is None:
            # Already cleaned up: honor the iterator protocol instead of
            # failing with AttributeError on a None cursor.
            raise StopIteration
        result = self.cursor.fetchone()
        if result is None:
            self._cleanup()
            raise StopIteration
        if result[0] is None:
            return None
        if self.select.ops.get('lazyColumns', 0):
            # Only the id was selected; other columns load later.
            return self.select.sourceClass.get(result[0],
                                               connection=self.dbconn)
        return self.select.sourceClass.get(result[0],
                                           selectResults=result[1:],
                                           connection=self.dbconn)

    def _cleanup(self):
        """Release the connection (unless kept) and drop references; safe
        to call more than once."""
        if getattr(self, 'query', None) is None:
            # already cleaned up
            return
        self.query = None
        if not self.keepConnection:
            self.dbconn.releaseConnection(self.rawconn)
        self.dbconn = self.rawconn = self.select = self.cursor = None

    def __del__(self):
        self._cleanup()
0785
0786
class Transaction(object):
    """A unit of work running on one raw connection of a DB connection.

    ``__init__`` checks a raw connection out of ``dbConnection`` and turns
    autocommit off on it.  Queries are routed through that raw connection
    until ``rollback()`` or ``commit(close=True)`` marks the transaction
    obsolete and releases the connection; ``begin()`` can then start a new
    session on the same object.  Attributes not defined here are looked up
    on the parent connection by ``__getattr__``.
    """

    def __init__(self, dbConnection):
        # this is to skip __del__ in case of an exception in this __init__
        self._obsolete = True
        self._dbConnection = dbConnection
        self._connection = dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0)
        self.cache = CacheSet(cache=dbConnection.doCache)
        # class name -> list of ids deleted through this transaction; those
        # objects are expired in the parent connection's cache on commit.
        self._deletedCache = {}
        self._obsolete = False

    def assertActive(self):
        """Raise AssertionError if this transaction is already obsolete.

        An explicit raise (instead of an ``assert`` statement) keeps the
        check in place when Python runs with ``-O``.
        """
        if self._obsolete:
            raise AssertionError(
                "This transaction has already gone through ROLLBACK; "
                "begin another transaction")

    def query(self, s):
        """Run ``s`` on the transaction's raw connection."""
        self.assertActive()
        return self._dbConnection._query(self._connection, s)

    def queryAll(self, s):
        """Run ``s`` on the transaction's raw connection; return all rows."""
        self.assertActive()
        return self._dbConnection._queryAll(self._connection, s)

    def queryOne(self, s):
        """Run ``s`` on the transaction's raw connection; return one row."""
        self.assertActive()
        return self._dbConnection._queryOne(self._connection, s)

    def queryInsertID(self, soInstance, id, names, values):
        """Delegate an INSERT to the parent connection on our raw connection."""
        self.assertActive()
        return self._dbConnection._queryInsertID(
            self._connection, soInstance, id, names, values)

    def iterSelect(self, select):
        """Return an iterator over ``select`` with all results prefetched."""
        self.assertActive()
        # We can't keep the cursor open with results in a transaction,
        # because we might want to use the connection while we're
        # still iterating through the results.
        # @@: But would it be okay for psycopg, with threadsafety
        # level 2?
        return iter(list(select.IterationClass(self, self._connection,
                                               select, keepConnection=True)))

    def _SO_delete(self, inst):
        """Record ``inst`` as deleted, then run the parent connection's
        ``_SO_delete`` with this transaction bound as ``self``."""
        clsName = inst.__class__.__name__
        self._deletedCache.setdefault(clsName, []).append(inst.id)
        func = self._dbConnection._SO_delete.__func__
        if PY2:
            meth = types.MethodType(func, self, self.__class__)
        else:
            meth = types.MethodType(func, self)
        return meth(inst)

    def commit(self, close=False):
        """Commit the raw connection and expire the affected cached objects.

        Every id cached or deleted through this transaction is expired in the
        parent connection's cache.  With ``close=True`` the transaction also
        becomes obsolete and its raw connection is released.  A no-op on an
        obsolete transaction.
        """
        if self._obsolete:
            # @@: is it okay to get extraneous commits?
            return
        if self._dbConnection.debug:
            self._dbConnection.printDebug(self._connection, '', 'COMMIT')
        self._connection.commit()
        byClassName = self.cache.allSubCachesByClassNames()
        subCaches = [(name, sub.allIDs()) for name, sub in byClassName.items()]
        # Objects deleted inside the transaction must be expired as well.
        subCaches.extend(self._deletedCache.items())
        mainCache = self._dbConnection.cache
        for cls, ids in subCaches:
            for objID in list(ids):
                inst = mainCache.tryGetByName(objID, cls)
                if inst is not None:
                    inst.expire()
        if close:
            self._makeObsolete()

    def rollback(self):
        """Roll back, expire this transaction's cached objects, and release
        the raw connection.  A no-op on an obsolete transaction."""
        if self._obsolete:
            # @@: is it okay to get extraneous rollbacks?
            return
        if self._dbConnection.debug:
            self._dbConnection.printDebug(self._connection, '', 'ROLLBACK')
        # Snapshot the cached ids before rolling back the raw connection.
        subCaches = [(sub, sub.allIDs()) for sub in self.cache.allSubCaches()]
        self._connection.rollback()

        for subCache, ids in subCaches:
            for objID in list(ids):
                inst = subCache.tryGet(objID)
                if inst is not None:
                    inst.expire()
        self._makeObsolete()

    def __getattr__(self, attr):
        """
        If nothing else works, let the parent connection handle it.
        Except with this transaction as 'self'.  Poor man's
        acquisition?  Bad programming?  Okay, maybe.
        """
        # __getattr__ only runs for names missing from the instance.  If the
        # transaction's own state is missing (e.g. an instance created via
        # __new__ without __init__, as copy/pickle do), delegating would
        # recurse through assertActive() -> self._obsolete forever; fail like
        # an ordinary missing attribute instead.
        if attr in ('_obsolete', '_dbConnection'):
            raise AttributeError(attr)
        self.assertActive()
        value = getattr(self._dbConnection, attr)
        try:
            func = value.__func__
        except AttributeError:
            # Not a bound method: wrap class-bound connection wrappers so they
            # point at this transaction, pass everything else through.
            if isinstance(value, ConnWrapper):
                return ConnWrapper(value._soClass, self)
            return value
        if PY2:
            return types.MethodType(func, self, self.__class__)
        return types.MethodType(func, self)

    def _makeObsolete(self):
        """Mark obsolete, restore autocommit if the parent uses it, and
        release the raw connection back to the parent."""
        self._obsolete = True
        if self._dbConnection.autoCommit:
            self._dbConnection._setAutoCommit(self._connection, 1)
        self._dbConnection.releaseConnection(self._connection,
                                             explicit=True)
        self._connection = None
        self._deletedCache = {}

    def begin(self):
        """Start a new session on a fresh raw connection.

        Raises AssertionError if the current session is still active.
        """
        # @@: Should we do this, or should begin() be a no-op when we're
        # not already obsolete?
        if not self._obsolete:
            raise AssertionError(
                "You cannot begin a new transaction session "
                "without rolling back this one")
        self._obsolete = False
        self._connection = self._dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0)

    def __del__(self):
        # Roll back a still-active transaction on garbage collection.  The
        # getattr default covers instances that were never initialised.
        if getattr(self, '_obsolete', True):
            return
        self.rollback()

    def close(self):
        raise TypeError('You cannot just close transaction - '
                        'you should either call rollback(), commit() '
                        'or commit(close=True) '
                        'to close the underlying connection.')
0929
0930
class ConnectionHub(object):

    """
    This object serves as a hub for connections, so that you can pass
    in a ConnectionHub to a SQLObject subclass as though it was a
    connection, but actually bind a real database connection later.
    You can also bind connections on a per-thread basis.

    You must hang onto the original ConnectionHub instance, as you
    cannot retrieve it again from the class or instance.

    To use the hub, do something like::

        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub

        hub.threadConnection = connectionFromURI('...')

    """

    def __init__(self):
        self.threadingLocal = threading_local()

    def __get__(self, obj, type=None):
        # I'm a little surprised we have to do this, but apparently
        # the object's private dictionary of attributes doesn't
        # override this descriptor.
        if (obj is not None) and '_connection' in obj.__dict__:
            return obj.__dict__['_connection']
        return self.getConnection()

    def __set__(self, obj, value):
        obj.__dict__['_connection'] = value

    def _currentConnection(self):
        """Return ``(connection, isThreadLocal)`` for the active binding.

        The per-thread connection wins over ``processConnection``.  Raises
        AttributeError if neither has been bound.
        """
        try:
            return self.threadingLocal.connection, True
        except AttributeError:
            pass
        try:
            return self.processConnection, False
        except AttributeError:
            raise AttributeError(
                "No connection has been defined for this thread "
                "or process")

    def getConnection(self):
        """Return the thread's connection, falling back to the process one.

        Raises AttributeError if no connection is bound.
        """
        return self._currentConnection()[0]

    def doInTransaction(self, func, *args, **kw):
        """
        This routine can be used to run a function in a transaction,
        rolling the transaction back if any exception is raised from
        that function, and committing otherwise.

        Use like::

            sqlhub.doInTransaction(process_request, os.environ)

        This will run ``process_request(os.environ)``.  The return
        value will be preserved.

        While ``func`` runs, the transaction temporarily replaces whichever
        binding (thread or process) was active; the original binding is
        restored afterwards.  Raises AttributeError if no connection is
        bound.
        """
        # @@: In Python 2.5, something usable with with: should also
        # be added.
        old_conn, old_conn_is_threading = self._currentConnection()
        conn = old_conn.transaction()
        if old_conn_is_threading:
            self.threadConnection = conn
        else:
            self.processConnection = conn
        try:
            value = func(*args, **kw)
        except BaseException:
            # Any exception, including KeyboardInterrupt/SystemExit, rolls
            # the transaction back before propagating.
            conn.rollback()
            raise
        else:
            conn.commit(close=True)
            return value
        finally:
            if old_conn_is_threading:
                self.threadConnection = old_conn
            else:
                self.processConnection = old_conn

    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value

    def _get_threadConnection(self):
        return self.threadingLocal.connection

    def _del_threadConnection(self):
        del self.threadingLocal.connection

    threadConnection = property(_get_threadConnection,
                                _set_threadConnection,
                                _del_threadConnection)
1030
1031
class ConnectionURIOpener(object):
    """Registry that turns connection URIs (or registered names) into
    connection objects.

    Drivers register a builder per URI scheme; ``connectionForURI`` asks the
    builder for the URI's scheme for a connection class, builds a connection
    from the URI and caches it by the full URI string.
    """

    def __init__(self):
        # URI scheme -> zero-argument builder returning a connection class
        self.schemeBuilders = {}
        # registered name -> connection instance
        self.instanceNames = {}
        # full URI (including any encoded extra arguments) -> connection
        self.cachedURIs = {}

    def registerConnection(self, schemes, builder):
        """Register ``builder`` for every URI scheme in ``schemes``.

        Re-registering the same builder is allowed; a different builder for
        an already-claimed scheme raises AssertionError.
        """
        for uriScheme in schemes:
            if (uriScheme in self.schemeBuilders
                    and self.schemeBuilders[uriScheme] is not builder):
                raise AssertionError("A driver has already been registered "
                                     "for the URI scheme %s" % uriScheme)
            self.schemeBuilders[uriScheme] = builder

    def registerConnectionInstance(self, inst):
        """Register connection ``inst`` under ``inst.name``, if it has one.

        Registering the same instance twice is a no-op.  A different
        instance under an already-used name, or a name containing ':'
        (which ``connectionForURI`` would parse as a URI scheme), raises
        AssertionError.
        """
        if not inst.name:
            return
        # Compare against ``inst`` itself: the previous code referenced an
        # undefined ``cls`` here, raising NameError on any re-registration.
        if (inst.name in self.instanceNames
                and self.instanceNames[inst.name] is not inst):
            raise AssertionError("A instance has already been registered "
                                 "with the name %s" % inst.name)
        if ':' in inst.name:
            raise AssertionError("You cannot include ':' "
                                 "in your class names (%r)" % inst.name)
        self.instanceNames[inst.name] = inst

    def connectionForURI(self, uri, oldUri=False, **args):
        """Return the (cached) connection for ``uri``.

        Extra keyword arguments are URL-encoded and appended to the query
        string.  A ``uri`` without ':' is treated as a registered instance
        name.  ``oldUri=True`` parses the URI with the driver's
        ``connectionFromOldURI`` instead of ``connectionFromURI``.
        Raises AssertionError for an unknown scheme or name.
        """
        if args:
            separator = '&' if '?' in uri else '?'
            uri += separator + urlencode(args)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if ':' in uri:
            scheme = uri.split(':', 1)[0]
            connCls = self.dbConnectionForScheme(scheme)
            if oldUri:
                conn = connCls.connectionFromOldURI(uri)
            else:
                conn = connCls.connectionFromURI(uri)
        else:
            # We just have a name, not a URI
            if uri not in self.instanceNames:
                raise AssertionError(
                    "No SQLObject driver exists under the name %s" % uri)
            conn = self.instanceNames[uri]
        # @@: Do we care if we clobber another connection?
        self.cachedURIs[uri] = conn
        return conn

    def dbConnectionForScheme(self, scheme):
        """Call the builder registered for ``scheme`` and return its result.

        Raises AssertionError (listing the known schemes) if none exists.
        """
        if scheme not in self.schemeBuilders:
            raise AssertionError(
                "No SQLObject driver exists for %s (only %s)" % (
                    scheme,
                    ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()
1088
# Shared module-level registry.  The driver modules imported below are
# imported for their side effect of registering URI schemes.
TheURIOpener = ConnectionURIOpener()

# Convenience aliases: the opener's bound methods exposed as module functions.
(registerConnection,
 registerConnectionInstance,
 connectionForURI,
 dbConnectionForScheme) = (TheURIOpener.registerConnection,
                           TheURIOpener.registerConnectionInstance,
                           TheURIOpener.connectionForURI,
                           TheURIOpener.dbConnectionForScheme)
1095
1096# Register DB URI schemas -- do import for side effects
1097# noqa is a directive for flake8 to ignore seemingly unused imports
1098from . import firebird  # noqa
1099from . import maxdb  # noqa
1100from . import mssql  # noqa
1101from . import mysql  # noqa
1102from . import postgres  # noqa
1103from . import rdbhost  # noqa
1104from . import sqlite  # noqa
1105from . import sybase  # noqa