0001import atexit
0002from cgi import parse_qsl
0003import inspect
0004import new
0005import os
0006import sys
0007import threading
0008import types
0009import urllib
0010import warnings
0011import weakref
0012
0013from cache import CacheSet
0014import classregistry
0015import col
0016from converters import sqlrepr
0017import main
0018import sqlbuilder
0019from util.threadinglocal import local as threading_local
0020
# Silence the DB-API "extension used" warning that a driver may emit when
# cursor.lastrowid is read (presumably by backend insert-id code).
warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")

# Module-level connection registry; presumably populated by
# registerConnectionInstance(), which is not shown here -- verify.
_connections = {}
0024
0025def _closeConnection(ref):
0026    conn = ref()
0027    if conn is not None:
0028        conn.close()
0029
0030class ConsoleWriter:
0031    def __init__(self, connection, loglevel):
0032        # loglevel: None or empty string for stdout; or 'stderr'
0033        self.loglevel = loglevel or "stdout"
0034        self.dbEncoding = getattr(connection, "dbEncoding", None) or "ascii"
0035    def write(self, text):
0036        logfile = getattr(sys, self.loglevel)
0037        if isinstance(text, unicode):
0038            try:
0039                text = text.encode(self.dbEncoding)
0040            except UnicodeEncodeError:
0041                text = repr(text)[2:-1] # Remove u'...' from the repr
0042        logfile.write(text + '\n')
0043
class LogWriter:
    """
    Debug writer that forwards each message to a method of a `logging`
    logger (``logger.debug``, ``logger.info``, ...).
    """
    def __init__(self, connection, logger, loglevel):
        # `connection` is accepted for signature parity with ConsoleWriter
        # and is not used here.
        self.logger = logger
        # With no level, getattr(logger, None) used to raise TypeError; fall
        # back to logger.debug, mirroring ConsoleWriter's "stdout" default.
        self.loglevel = loglevel or "debug"
        self.logmethod = getattr(logger, self.loglevel)

    def write(self, text):
        """Emit `text` through the configured logger method."""
        self.logmethod(text)
0051
def makeDebugWriter(connection, loggerName, loglevel):
    """
    Build the object a connection writes its debug output to.

    With a logger name, return a LogWriter around ``logging.getLogger(name)``.
    Otherwise return a ConsoleWriter that prints to a standard stream.
    `loglevel` is passed through unchanged to whichever writer is built.
    """
    if loggerName:
        # Imported lazily: logging is only needed when a logger is requested.
        import logging
        return LogWriter(connection, logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(connection, loglevel)
0058
class Boolean(object):
    """
    Coerce a value to ``bool``, also recognising the keyword strings
    yes/no, true/false, on/off and 1/0 (case-insensitively).

    Calling ``Boolean(x)`` returns a plain ``bool``, never a ``Boolean``
    instance.  Values that are not one of the keywords (including objects
    without a ``lower()`` method) fall back to ``bool(value)``.
    """
    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
                 '0': False, 'no': False, 'false': False, 'off': False}

    def __new__(cls, value):
        table = Boolean._keywords
        try:
            result = table[value.lower()]
        except (KeyError, AttributeError):
            result = bool(value)
        return result
0068
class DBConnection:

    """
    Base class holding the settings shared by every database connection.

    It stores the debugging, caching, naming-style and auto-commit options.
    It registers `soClassAdded` as a callback on the class registry, so each
    new SQLObject class becomes reachable as an attribute of the connection.
    It also renders and parses connection URIs.

    Subclasses supply the backend specifics.  These include the ``dbName``,
    ``host``, ``port`` and ``db`` attributes (and optionally
    ``user``/``password``) read by `uri()`.  They also include the
    ``_connectionFromParams`` constructor used by the
    ``connectionFrom*URI`` class methods.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        self.name = name
        # Boolean() collapses any non-keyword string to True.  debug='Pool'
        # therefore became True, the ``self.debug != 'Pool'`` gate in
        # DBAPI.printDebug was always true, and pool messages could never be
        # enabled.  Keep that keyword verbatim; it is still truthy.
        if isinstance(debug, basestring) and debug.lower() == 'pool':
            self.debug = 'Pool'
        else:
            self.debug = Boolean(debug)
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        self.debugWriter = makeDebugWriter(self, logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        self.style = style
        # id(raw connection) -> sequence number shown in debug output.
        self._connectionNumbers = {}
        self._connectionCount = 1
        # NOTE(review): Boolean() likewise turns autoCommit='exception' into
        # True, so the 'exception' branch of DBAPI.releaseConnection is
        # unreachable.  Left unchanged: enabling it would alter commit
        # behaviour for existing callers.
        self.autoCommit = Boolean(autoCommit)
        self.registry = registry or None
        classregistry.registry(self.registry).addCallback(self.soClassAdded)
        registerConnectionInstance(self)
        # A weak reference keeps the atexit hook from pinning this object.
        atexit.register(_closeConnection, weakref.ref(self))

    def oldUri(self):
        """Return this connection's URI in the legacy, unquoted format."""
        return self._buildURI(lambda part: part)

    def uri(self):
        """
        Return this connection's URI with the user name, password and
        database path percent-quoted (the format `connectionFromURI` reads).
        """
        return self._buildURI(urllib.quote)

    def _buildURI(self, quote):
        """
        Assemble ``dbName://[user[:password]@][host[:port]]/db``.

        `quote` is applied to the user name, password and database path.
        Missing ``user``/``password`` attributes are treated as empty.
        """
        auth = getattr(self, 'user', '') or ''
        password = getattr(self, 'password', None)
        if auth:
            auth = quote(auth)
            if password:
                auth = auth + ':' + quote(password)
            auth = auth + '@'
        else:
            assert not password, (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + quote(db)

    @classmethod
    def connectionFromOldURI(cls, uri):
        """Create a connection from a legacy (unquoted) URI."""
        return cls._connectionFromParams(*cls._parseOldURI(uri))

    @classmethod
    def connectionFromURI(cls, uri):
        """Create a connection from a percent-quoted URI."""
        return cls._connectionFromParams(*cls._parseURI(uri))

    @staticmethod
    def _parseOldURI(uri):
        """
        Parse a legacy URI.

        Returns ``(user, password, host, port, path, args)``, where `port` is
        an int or None and `args` maps query parameters to unquoted values.
        Raises ValueError for a non-integer or out-of-range port.
        """
        schema, rest = uri.split(':', 1)
        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
        if rest.startswith('/') and not rest.startswith('//'):
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            host = None
            rest = rest[3:]
        else:
            rest = rest[2:]
            if rest.find('/') == -1:
                host = rest
                rest = ''
            else:
                host, rest = rest.split('/', 1)
        if host and host.find('@') != -1:
            user, host = host.rsplit('@', 1)
            if user.find(':') != -1:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and host.find(':') != -1:
            _host, port = host.split(':')
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
            host = _host
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            # Legacy Windows drive syntax: /C|/path -> C:/path
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if path.find('?') != -1:
            path, arglist = path.split('?', 1)
            arglist = arglist.split('&')
            for single in arglist:
                argname, argvalue = single.split('=', 1)
                argvalue = urllib.unquote(argvalue)
                args[argname] = argvalue
        return user, password, host, port, path, args

    @staticmethod
    def _parseURI(uri):
        """
        Parse a percent-quoted URI (as produced by `uri()`).

        Returns the same ``(user, password, host, port, path, args)`` tuple
        as `_parseOldURI`.  Any ``#fragment`` is discarded.
        """
        protocol, request = urllib.splittype(uri)
        user, password, port = None, None, None
        host, path = urllib.splithost(request)

        if host:
            # Python < 2.7 have a problem - splituser() calls unquote() too early
            #user, host = urllib.splituser(host)
            if '@' in host:
                # Split on the *last* '@', as _parseOldURI does, so an
                # unquoted '@' inside the password does not leak into the host.
                user, host = host.rsplit('@', 1)
            if user:
                user, password = [x and urllib.unquote(x) or None for x in urllib.splitpasswd(user)]
            host, port = urllib.splitport(host)
            if port: port = int(port)
        elif host == '':
            host = None

        # hash-tag is splitted but ignored
        path, tag = urllib.splittag(path)
        path, query = urllib.splitquery(path)

        path = urllib.unquote(path)
        if (os.name == 'nt') and (len(path) > 2):
            # Preserve backward compatibility with URIs like /C|/path;
            # replace '|' by ':'
            if path[2] == '|':
                path = "%s:%s" % (path[0:2], path[3:])
            # Remove leading slash
            if (path[0] == '/') and (path[2] == ':'):
                path = path[1:]

        args = {}
        if query:
            for name, value in parse_qsl(query):
                args[name] = value

        return user, password, host, port, path, args

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()
0250
class ConnWrapper(object):

    """
    This represents a SQLObject class that is bound to a specific
    connection.  Instances have a connection instance variable, but
    classes are global, so this binds the connection variable lazily
    when a class method is accessed.

    Calling the wrapper instantiates the class with ``connection=`` set.
    Accessing a method whose signature has a ``connection`` parameter
    returns a ConnMethodWrapper that supplies that argument.
    """
    # @@: methods that take connection arguments should be explicitly
    # marked up instead of the implicit use of a connection argument
    # and inspect.getargspec()

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        meth = getattr(self._soClass, attr)
        if not isinstance(meth, types.MethodType):
            # We don't need to wrap non-methods
            return meth
        try:
            takes_conn = meth.takes_connection
        except AttributeError:
            # Inspect the signature once and cache the answer on the
            # underlying function for subsequent lookups.
            args, varargs, varkw, defaults = inspect.getargspec(meth)
            assert not varkw and not varargs, (
                "I cannot tell whether I must wrap this method, "
                "because it takes *args or **kw: %r"
                % meth)
            takes_conn = 'connection' in args
            meth.im_func.takes_connection = takes_conn
        if not takes_conn:
            return meth
        return ConnMethodWrapper(meth, self._connection)
0289
class ConnMethodWrapper(object):

    """
    Callable proxy that invokes a method with its ``connection`` keyword
    forced to a fixed connection.

    Any attribute not found on the proxy itself is looked up on the
    wrapped method.
    """

    def __init__(self, method, connection):
        self._method, self._connection = method, connection

    def __call__(self, *args, **kw):
        kw.update(connection=self._connection)
        return self._method(*args, **kw)

    def __getattr__(self, attr):
        # Only reached for names missing on the proxy: delegate them.
        return getattr(self._method, attr)

    def __repr__(self):
        template = '<Wrapped %r with connection %r>'
        return template % (self._method, self._connection)
0306
class DBAPI(DBConnection):

    """
    Base class for connections built on a Python DB-API driver.

    Keeps a simple pool of raw driver connections.  Implements the generic
    query helpers and the ``_SO_*`` "friend" methods that SQLObject uses to
    read and write rows.

    Subclasses must provide the following.  They are used by this class but
    not defined in it:

    * ``module`` -- the DB-API driver module (``Binary`` and ``Error``
      are used here);
    * ``makeConnection()`` -- return a newly-created raw connection;
    * ``_queryInsertID(conn, soInstance, id, names, values)`` -- used by
      `queryInsertID`;
    * ``supportTransactions``, ``createIDColumn`` and ``joinSQLType``;
    * ``createColumn``, ``createReferenceConstraint`` and
      ``createIndexSQL``, which only assert here;
    * ``_setAutoCommit`` (called by `Transaction`).
    """

    # Backend name passed to sqlrepr() and written into URIs; set by subclasses.
    dbName = None

    def __init__(self, **kw):
        # Create the pool first: close() (also run from __del__) uses the
        # presence of `_pool` to detect a half-constructed instance.
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        # Type produced by the driver's Binary() wrapper.  Not used in this
        # class; presumably consulted by subclasses/converters -- verify.
        self._binaryType = type(self.module.Binary(''))

    def _runWithConnection(self, meth, *args):
        """
        Call ``meth(conn, *args)`` on a connection from `getConnection` and
        always release that connection afterwards.  Returns meth's result.
        """
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val

    def getConnection(self):
        """
        Return a raw connection: a pooled one if available, else a new one
        from makeConnection(), numbered for debug output.

        Releases insert at the front of the pool and this pops from the back,
        so the connection idle longest is reused first.
        """
        self._poolLock.acquire()
        try:
            if not self._pool:
                conn = self.makeConnection()
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            else:
                conn = self._pool.pop()
            if self.debug:
                s = 'ACQUIRE'
                if self._pool is not None:
                    s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
                self.printDebug(conn, s, 'Pool')
            return conn
        finally:
            self._poolLock.release()

    def releaseConnection(self, conn, explicit=False):
        """
        Hand `conn` back after use.

        Unless `explicit` is true, and the backend supports transactions, an
        implicit COMMIT or ROLLBACK is issued first according to
        ``autoCommit``.  Transaction passes explicit=True because it commits
        or rolls back itself.  The connection then returns to the pool, or
        is closed when pooling is disabled (``_pool`` is None).
        """
        if self.debug:
            if explicit:
                s = 'RELEASE (explicit)'
            else:
                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                s += ' no pooling'
            else:
                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        if self.supportTransactions and not explicit:
            # NOTE(review): DBConnection.__init__ coerces autoCommit through
            # Boolean(), so this 'exception' branch is currently unreachable.
            # If it is ever enabled, note that raising here skips returning
            # `conn` to the pool.
            if self.autoCommit == 'exception':
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                raise Exception('Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed')
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                # NOTE(review): for drivers where `autocommit` is a method
                # rather than a flag, this getattr is always truthy and the
                # commit is skipped -- confirm that is intended.
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        if self._pool is not None:
            if conn not in self._pool:
                # @@: We can get duplicate releasing of connections with
                # the __del__ in Iteration (unfortunately, not sure why
                # it happens)
                self._pool.insert(0, conn)
        else:
            conn.close()

    def printDebug(self, conn, s, name, type='query'):
        """
        Write one debug line for connection `conn` to the debug writer.

        Messages named 'Pool' are only emitted when ``self.debug == 'Pool'``.
        With ``type='query'``, `s` is printed as-is after ': '.  Otherwise
        ``repr(s)`` is printed after '->'.  `conn` must have been numbered by
        getConnection().
        """
        if name == 'Pool' and self.debug != 'Pool':
            return
        if type == 'query':
            sep = ': '
        else:
            sep = '->'
            s = repr(s)
        n = self._connectionNumbers[id(conn)]
        spaces = ' '*(8-len(name))
        if self.debugThreading:
            threadName = threading.currentThread().getName()
            threadName = (':' + threadName + ' '*(8-len(threadName)))
        else:
            threadName = ''
        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
        self.debugWriter.write(msg)

    def _executeRetry(self, conn, cursor, query):
        """Execute `query` on `cursor`; subclasses may add retry logic."""
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)

    def _query(self, conn, s):
        """Execute statement `s` on `conn`, discarding any result."""
        if self.debug:
            self.printDebug(conn, s, 'Query')
        self._executeRetry(conn, conn.cursor(), s)

    def query(self, s):
        """Execute statement `s` on a pooled connection."""
        return self._runWithConnection(self._query, s)

    def _queryAll(self, conn, s):
        """Execute `s` on `conn` and return all fetched rows."""
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return value

    def queryAll(self, s):
        """Execute `s` on a pooled connection and return all rows."""
        return self._runWithConnection(self._queryAll, s)

    def _queryAllDescription(self, conn, s):
        """
        Like queryAll, but returns (description, rows), where the
        description is cursor.description (which gives row types)
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return c.description, value

    def queryAllDescription(self, s):
        """Pooled-connection version of `_queryAllDescription`."""
        return self._runWithConnection(self._queryAllDescription, s)

    def _queryOne(self, conn, s):
        """Execute `s` on `conn` and return the first row (or None)."""
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchone()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryOne', 'result')
        return value

    def queryOne(self, s):
        """Execute `s` on a pooled connection and return the first row."""
        return self._runWithConnection(self._queryOne, s)

    def _insertSQL(self, table, names, values):
        """Build an INSERT statement; `values` are rendered with sqlrepr()."""
        return ("INSERT INTO %s (%s) VALUES (%s)" %
                (table, ', '.join(names),
                 ', '.join([self.sqlrepr(v) for v in values])))

    def transaction(self):
        """Start and return a new Transaction on this connection."""
        return Transaction(self)

    def queryInsertID(self, soInstance, id, names, values):
        """Run the subclass-provided `_queryInsertID` on a pooled connection."""
        return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)

    def iterSelect(self, select):
        """
        Return an iterator over `select`.  The iterator owns the connection
        and releases it itself (keepConnection=False).
        """
        return select.IterationClass(self, self.getConnection(),
                         select, keepConnection=False)

    def accumulateSelect(self, select, *expressions):
        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
            to the select object.  Returns a single value for one expression,
            else the whole result row.
        """
        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
        q = self.sqlrepr(q)
        val = self.queryOne(q)
        if len(expressions) == 1:
            val = val[0]
        return val

    def queryForSelect(self, select):
        """Render `select` as SQL text for this backend."""
        return self.sqlrepr(select.queryForSelect())

    def _SO_createJoinTable(self, join):
        self.query(self._SO_createJoinTableSQL(join))

    def _SO_createJoinTableSQL(self, join):
        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
                (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join)))

    def _SO_dropJoinTable(self, join):
        self.query("DROP TABLE %s" % join.intermediateTable)

    def _SO_createIndex(self, soClass, index):
        self.query(self.createIndexSQL(soClass, index))

    def createIndexSQL(self, soClass, index):
        assert 0, 'Implement in subclasses'

    def createTable(self, soClass):
        """
        Execute the CREATE TABLE for `soClass`.  Returns the deferred
        statements (reference constraints and sqlmeta.createSQL) for the
        caller to run.
        """
        createSql, constraints = self.createTableSQL(soClass)
        self.query(createSql)

        return constraints

    def createReferenceConstraints(self, soClass):
        """Return the non-empty constraint definitions for soClass's foreign keys."""
        refConstraints = [self.createReferenceConstraint(soClass, column)
                          for column in soClass.sqlmeta.columnList
                          if isinstance(column, col.SOForeignKey)]
        refConstraintDefs = [constraint
                             for constraint in refConstraints
                             if constraint]
        return refConstraintDefs

    def createSQL(self, soClass):
        """
        Normalise ``soClass.sqlmeta.createSQL`` into a list of statements.

        It may be a string, a list/tuple of strings, or a dict keyed by
        backend name.  The dict entry is chosen for *this* connection's
        dbName, the backend that will run the SQL, rather than the class's
        default connection.
        """
        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
        if tableCreateSQLs:
            assert isinstance(tableCreateSQLs, (basestring, list, dict, tuple)), (
                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
                (soClass.__name__))
            if isinstance(tableCreateSQLs, dict):
                tableCreateSQLs = tableCreateSQLs.get(self.dbName, [])
            if isinstance(tableCreateSQLs, basestring):
                tableCreateSQLs = [tableCreateSQLs]
            if isinstance(tableCreateSQLs, tuple):
                tableCreateSQLs = list(tableCreateSQLs)
            assert isinstance(tableCreateSQLs, list), (
                'Unable to create a list from %s.sqlmeta.createSQL' %
                (soClass.__name__))
        return tableCreateSQLs or []

    def createTableSQL(self, soClass):
        """Return ``(CREATE TABLE statement, list of deferred statements)``."""
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        createSql = ('CREATE TABLE %s (\n%s\n)' %
                (soClass.sqlmeta.table, self.createColumns(soClass)))
        return createSql, constraints + extraSQL

    def createColumns(self, soClass):
        """Return the indented, comma-separated column definitions."""
        columnDefs = [self.createIDColumn(soClass)] + \
                     [self.createColumn(soClass, column)
                      for column in soClass.sqlmeta.columnList]
        return ",\n".join(["    %s" % c for c in columnDefs])

    def createReferenceConstraint(self, soClass, col):
        assert 0, "Implement in subclasses"

    def createColumn(self, soClass, col):
        assert 0, "Implement in subclasses"

    def dropTable(self, tableName, cascade=False):
        self.query("DROP TABLE %s" % tableName)

    def clearTable(self, tableName):
        # 3-03 @@: Should this have a WHERE 1 = 1 or similar
        # clause?  In some configurations without the WHERE clause
        # the query won't go through, but maybe we shouldn't override
        # that.
        self.query("DELETE FROM %s" % tableName)

    def createBinary(self, value):
        """
        Create a binary object wrapper for the given database.
        """
        # Default is Binary() function from the connection driver.
        return self.module.Binary(value)

    # The _SO_* series of methods are sorts of "friend" methods
    # with SQLObject.  They grab values from the SQLObject instances
    # or classes freely, but keep the SQLObject class from accessing
    # the database directly.  This way no SQL is actually created
    # in the SQLObject class.

    def _SO_update(self, so, values):
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                               for dbName, value in values]),
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectOne(self, so, columnNames):
        return self._SO_selectOneAlt(so, columnNames, so.q.id==so.id)


    def _SO_selectOneAlt(self, so, columnNames, condition):
        """
        Select `columnNames` (or the default columns when empty) from so's
        table where `condition` holds, and return the first row.  Plain
        string column names are wrapped in SQLConstant.
        """
        if columnNames:
            columns = []
            for name in columnNames:
                if isinstance(name, basestring):
                    name = sqlbuilder.SQLConstant(name)
                columns.append(name)
        else:
            columns = None
        return self.queryOne(self.sqlrepr(sqlbuilder.Select(columns,
                                                            staticTables=[so.sqlmeta.table],
                                                            clause=condition)))

    def _SO_delete(self, so):
        self.query("DELETE FROM %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectJoin(self, soClass, column, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (soClass.sqlmeta.idName,
                              soClass.sqlmeta.table,
                              column,
                              self.sqlrepr(value)))

    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (getColumn,
                              table,
                              joinColumn,
                              self.sqlrepr(value)))

    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
                   (table,
                    firstColumn,
                    self.sqlrepr(firstValue),
                    secondColumn,
                    self.sqlrepr(secondValue)))

    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
                   (table,
                    firstColumn,
                    secondColumn,
                    self.sqlrepr(firstValue),
                    self.sqlrepr(secondValue)))

    def _SO_columnClause(self, soClass, kw):
        """
        Build a WHERE clause (``a = x AND b IS NULL ...``) from keyword
        arguments naming ``id``, columns or foreign-key names of `soClass`.

        Consumes the matched keys from `kw`.  Returns None when nothing
        matched.  Raises TypeError for leftover keys.
        """
        data = {}
        if 'id' in kw:
            data[soClass.sqlmeta.idName] = kw.pop('id')
        for key, column in soClass.sqlmeta.columns.items():
            if key in kw:
                value = kw.pop(key)
                if column.from_python:
                    value = column.from_python(value, sqlbuilder.SQLObjectState(soClass, connection=self))
                data[column.dbName] = value
            elif column.foreignName in kw:
                obj = kw.pop(column.foreignName)
                if isinstance(obj, main.SQLObject):
                    data[column.dbName] = obj.id
                else:
                    data[column.dbName] = obj
        if kw:
            # Every remaining key matched neither 'id', a column name nor a
            # foreign-key name; report all of them.
            raise TypeError("got an unexpected keyword argument(s): %r" % kw.keys())

        if not data:
            return None
        clauses = []
        for dbName, value in data.items():
            # NULL needs IS; test identity instead of hashing `value`.
            if value is None:
                op = "IS"
            else:
                op = "="
            clauses.append('%s %s %s' % (dbName, op, self.sqlrepr(value)))
        return ' AND '.join(clauses)

    def sqlrepr(self, v):
        """Render `v` as an SQL literal for this backend."""
        return sqlrepr(v, self.dbName)

    def __del__(self):
        self.close()

    def close(self):
        """Close every pooled connection, ignoring driver errors."""
        if not hasattr(self, '_pool'):
            # Probably there was an exception while creating this
            # instance, so it is incomplete.
            return
        if not self._pool:
            # Cheap unlocked fast path; re-checked under the lock below.
            return
        self._poolLock.acquire()
        try:
            # Snapshot and clear under the lock.  Another thread may have
            # drained the pool since the check above, so `conns` can be empty.
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    pass
        finally:
            self._poolLock.release()

    def createEmptyDatabase(self):
        """
        Create an empty database.
        """
        raise NotImplementedError
0702
0703class Iteration(object):
0704
0705    def __init__(self, dbconn, rawconn, select, keepConnection=False):
0706        self.dbconn = dbconn
0707        self.rawconn = rawconn
0708        self.select = select
0709        self.keepConnection = keepConnection
0710        self.cursor = rawconn.cursor()
0711        self.query = self.dbconn.queryForSelect(select)
0712        if dbconn.debug:
0713            dbconn.printDebug(rawconn, self.query, 'Select')
0714        self.dbconn._executeRetry(self.rawconn, self.cursor, self.query)
0715
0716    def __iter__(self):
0717        return self
0718
0719    def next(self):
0720        result = self.cursor.fetchone()
0721        if result is None:
0722            self._cleanup()
0723            raise StopIteration
0724        if result[0] is None:
0725            return None
0726        if self.select.ops.get('lazyColumns', 0):
0727            obj = self.select.sourceClass.get(result[0], connection=self.dbconn)
0728            return obj
0729        else:
0730            obj = self.select.sourceClass.get(result[0], selectResults=result[1:], connection=self.dbconn)
0731            return obj
0732
0733    def _cleanup(self):
0734        if getattr(self, 'query', None) is None:
0735            # already cleaned up
0736            return
0737        self.query = None
0738        if not self.keepConnection:
0739            self.dbconn.releaseConnection(self.rawconn)
0740        self.dbconn = self.rawconn = self.select = self.cursor = None
0741
0742    def __del__(self):
0743        self._cleanup()
0744
0745class Transaction(object):
0746
0747    def __init__(self, dbConnection):
0748        # this is to skip __del__ in case of an exception in this __init__
0749        self._obsolete = True
0750        self._dbConnection = dbConnection
0751        self._connection = dbConnection.getConnection()
0752        self._dbConnection._setAutoCommit(self._connection, 0)
0753        self.cache = CacheSet(cache=dbConnection.doCache)
0754        self._deletedCache = {}
0755        self._obsolete = False
0756
0757    def assertActive(self):
0758        assert not self._obsolete, "This transaction has already gone through ROLLBACK; begin another transaction"
0759
0760    def query(self, s):
0761        self.assertActive()
0762        return self._dbConnection._query(self._connection, s)
0763
0764    def queryAll(self, s):
0765        self.assertActive()
0766        return self._dbConnection._queryAll(self._connection, s)
0767
0768    def queryOne(self, s):
0769        self.assertActive()
0770        return self._dbConnection._queryOne(self._connection, s)
0771
0772    def queryInsertID(self, soInstance, id, names, values):
0773        self.assertActive()
0774        return self._dbConnection._queryInsertID(
0775            self._connection, soInstance, id, names, values)
0776
0777    def iterSelect(self, select):
0778        self.assertActive()
0779        # We can't keep the cursor open with results in a transaction,
0780        # because we might want to use the connection while we're
0781        # still iterating through the results.
0782        # @@: But would it be okay for psycopg, with threadsafety
0783        # level 2?
0784        return iter(list(select.IterationClass(self, self._connection,
0785                                   select, keepConnection=True)))
0786
0787    def _SO_delete(self, inst):
0788        cls = inst.__class__.__name__
0789        if not cls in self._deletedCache:
0790            self._deletedCache[cls] = []
0791        self._deletedCache[cls].append(inst.id)
0792        meth = new.instancemethod(self._dbConnection._SO_delete.im_func, self, self.__class__)
0793        return meth(inst)
0794
0795    def commit(self, close=False):
0796        if self._obsolete:
0797            # @@: is it okay to get extraneous commits?
0798            return
0799        if self._dbConnection.debug:
0800            self._dbConnection.printDebug(self._connection, '', 'COMMIT')
0801        self._connection.commit()
0802        subCaches = [(sub[0], sub[1].allIDs()) for sub in self.cache.allSubCachesByClassNames().items()]
0803        subCaches.extend([(x[0], x[1]) for x in self._deletedCache.items()])
0804        for cls, ids in subCaches:
0805            for id in ids:
0806                inst = self._dbConnection.cache.tryGetByName(id, cls)
0807                if inst is not None:
0808                    inst.expire()
0809        if close:
0810            self._makeObsolete()
0811
0812    def rollback(self):
0813        if self._obsolete:
0814            # @@: is it okay to get extraneous rollbacks?
0815            return
0816        if self._dbConnection.debug:
0817            self._dbConnection.printDebug(self._connection, '', 'ROLLBACK')
0818        subCaches = [(sub, sub.allIDs()) for sub in self.cache.allSubCaches()]
0819        self._connection.rollback()
0820
0821        for subCache, ids in subCaches:
0822            for id in ids:
0823                inst = subCache.tryGet(id)
0824                if inst is not None:
0825                    inst.expire()
0826        self._makeObsolete()
0827
0828    def __getattr__(self, attr):
0829        """
0830        If nothing else works, let the parent connection handle it.
0831        Except with this transaction as 'self'.  Poor man's
0832        acquisition?  Bad programming?  Okay, maybe.
0833        """
0834        self.assertActive()
0835        attr = getattr(self._dbConnection, attr)
0836        try:
0837            func = attr.im_func
0838        except AttributeError:
0839            if isinstance(attr, ConnWrapper):
0840                return ConnWrapper(attr._soClass, self)
0841            else:
0842                return attr
0843        else:
0844            meth = new.instancemethod(func, self, self.__class__)
0845            return meth
0846
0847    def _makeObsolete(self):
0848        self._obsolete = True
0849        if self._dbConnection.autoCommit:
0850            self._dbConnection._setAutoCommit(self._connection, 1)
0851        self._dbConnection.releaseConnection(self._connection,
0852                                             explicit=True)
0853        self._connection = None
0854        self._deletedCache = {}
0855
0856    def begin(self):
0857        # @@: Should we do this, or should begin() be a no-op when we're
0858        # not already obsolete?
0859        assert self._obsolete, "You cannot begin a new transaction session without rolling back this one"
0860        self._obsolete = False
0861        self._connection = self._dbConnection.getConnection()
0862        self._dbConnection._setAutoCommit(self._connection, 0)
0863
0864    def __del__(self):
0865        if self._obsolete:
0866            return
0867        self.rollback()
0868
0869    def close(self):
0870        raise TypeError('You cannot just close transaction - you should either call rollback(), commit() or commit(close=True) to close the underlying connection.')
0871
class ConnectionHub(object):

    """
    A stand-in connection that forwards to a real connection bound later.

    Assign a ConnectionHub to a SQLObject subclass's ``_connection`` as
    though it were a connection.  Then bind the real database connection
    afterwards, either for the whole process (``processConnection``) or
    for the current thread only (``threadConnection``).  A thread-level
    binding takes precedence over the process-level one.

    Keep a reference to the hub you create.  It cannot be retrieved again
    from the class or an instance, because reading ``_connection`` returns
    the bound connection instead.

    Example::

        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub

        hub.threadConnection = connectionFromURI('...')

    """

    def __init__(self):
        self.threadingLocal = threading_local()

    def __get__(self, obj, type=None):
        # Being a data descriptor (we define __set__), the hub is consulted
        # even when the instance has its own value, so honour an
        # instance-level ``_connection`` explicitly.
        if obj is not None:
            instanceDict = obj.__dict__
            if '_connection' in instanceDict:
                return instanceDict['_connection']
        return self.getConnection()

    def __set__(self, obj, value):
        # Store per-instance connections directly, bypassing the hub.
        obj.__dict__['_connection'] = value

    def getConnection(self):
        """
        Return the thread-bound connection, else the process-wide one.

        Raises AttributeError when neither has been set.
        """
        try:
            return self.threadingLocal.connection
        except AttributeError:
            pass
        try:
            return self.processConnection
        except AttributeError:
            raise AttributeError(
                "No connection has been defined for this thread "
                "or process")

    def doInTransaction(self, func, *args, **kw):
        """
        Run ``func(*args, **kw)`` inside a transaction.

        The currently bound connection (thread-level if set, else
        process-level) is temporarily replaced by a transaction opened on
        it.  If ``func`` raises, the transaction is rolled back and the
        exception propagates; otherwise it is committed and closed and
        ``func``'s return value is returned.  The original binding is
        restored in every case.

        Use like::

            sqlhub.doInTransaction(process_request, os.environ)

        This will run ``process_request(os.environ)``.
        """
        # @@: In Python 2.5, something usable with with: should also
        # be added.
        try:
            previous = self.threadingLocal.connection
        except AttributeError:
            previous = self.processConnection
            slot = 'processConnection'
        else:
            slot = 'threadConnection'
        trans = previous.transaction()
        setattr(self, slot, trans)
        try:
            try:
                result = func(*args, **kw)
            except:
                # Roll back on *any* exception, then re-raise it unchanged.
                trans.rollback()
                raise
            trans.commit(close=True)
            return result
        finally:
            setattr(self, slot, previous)

    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value

    def _get_threadConnection(self):
        return self.threadingLocal.connection

    def _del_threadConnection(self):
        del self.threadingLocal.connection

    # Per-thread connection binding, stored on the thread-local object.
    threadConnection = property(_get_threadConnection,
                                _set_threadConnection,
                                _del_threadConnection)
0971
class ConnectionURIOpener(object):

    """
    Registry that turns connection URIs (or registered names) into
    connection objects.

    Drivers register a builder per URI scheme.  Connection instances can
    also be registered under a plain name.  Every connection returned by
    connectionForURI() is cached under the exact URI string used to
    obtain it.
    """

    def __init__(self):
        # URI scheme -> zero-argument callable; the object it returns must
        # provide connectionFromURI() / connectionFromOldURI().
        self.schemeBuilders = {}
        # name -> connection instance registered via
        # registerConnectionInstance().
        self.instanceNames = {}
        # final URI string (including encoded keyword args) -> connection.
        self.cachedURIs = {}

    def registerConnection(self, schemes, builder):
        """
        Register ``builder`` as the driver for every scheme in ``schemes``.

        Registering the same builder for a scheme again is allowed.
        Registering a different builder for an already-registered scheme
        fails the assertion.
        """
        for uriScheme in schemes:
            assert (uriScheme not in self.schemeBuilders
                    or self.schemeBuilders[uriScheme] is builder), (
                "A driver has already been registered for the URI scheme %s"
                % uriScheme)
            self.schemeBuilders[uriScheme] = builder

    def registerConnectionInstance(self, inst):
        """
        Register connection ``inst`` under ``inst.name``, if it has one.

        Registering the same instance again is accepted.  These cases fail
        the assertion:
        - a different instance under a name that is already taken;
        - a name containing ':', which connectionForURI() would treat as
          a URI scheme.
        """
        if inst.name:
            # Compare against (and report) ``inst`` itself, so that
            # re-registering the same instance passes and the ':' check
            # produces its AssertionError rather than a NameError.
            assert (inst.name not in self.instanceNames
                    or self.instanceNames[inst.name] is inst), (
                "A instance has already been registered with the name %s"
                % inst.name)
            assert inst.name.find(':') == -1, (
                "You cannot include ':' in your class names (%r)" % inst.name)
            self.instanceNames[inst.name] = inst

    def connectionForURI(self, uri, oldUri=False, **args):
        """
        Return a connection for ``uri``, creating and caching it on first use.

        Extra keyword ``args`` are URL-encoded and appended to ``uri`` as
        query parameters before the lookup.  How ``uri`` is resolved:
        - if it contains ':', it goes to the driver registered for its
          scheme (``connectionFromOldURI`` when ``oldUri`` is true,
          otherwise ``connectionFromURI``);
        - a bare name must have been registered with
          registerConnectionInstance().
        """
        if args:
            if '?' not in uri:
                uri += '?' + urllib.urlencode(args)
            else:
                uri += '&' + urllib.urlencode(args)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if ':' in uri:
            scheme = uri.split(':', 1)[0]
            connCls = self.dbConnectionForScheme(scheme)
            if oldUri:
                conn = connCls.connectionFromOldURI(uri)
            else:
                conn = connCls.connectionFromURI(uri)
        else:
            # We just have a name, not a URI
            assert uri in self.instanceNames, (
                "No SQLObject driver exists under the name %s" % uri)
            conn = self.instanceNames[uri]
        # @@: Do we care if we clobber another connection?
        self.cachedURIs[uri] = conn
        return conn

    def dbConnectionForScheme(self, scheme):
        """Call and return the builder registered for ``scheme``."""
        assert scheme in self.schemeBuilders, (
               "No SQLObject driver exists for %s (only %s)"
               % (scheme, ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()
1023
# The single module-wide opener.  The functions below are its bound
# methods, so registrations and lookups made through any of them share
# the same scheme, instance-name and cached-URI tables.
TheURIOpener = ConnectionURIOpener()

registerConnection = TheURIOpener.registerConnection
registerConnectionInstance = TheURIOpener.registerConnectionInstance
dbConnectionForScheme = TheURIOpener.dbConnectionForScheme
connectionForURI = TheURIOpener.connectionForURI
1030
# Import the database backends; importing each one registers its DB URI scheme(s)
1032import firebird
1033import maxdb
1034import mssql
1035import mysql
1036import postgres
1037import rdbhost
1038import sqlite
1039import sybase