0001import atexit
0002from cgi import parse_qsl
0003import inspect
0004import new
0005import os
0006import sys
0007import threading
0008import types
0009import urllib
0010import warnings
0011import weakref
0012
0013from cache import CacheSet
0014import classregistry
0015import col
0016from converters import sqlrepr
0017import main
0018import sqlbuilder
0019from util.threadinglocal import local as threading_local
0020
0021warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")
0022
0023_connections = {}
0024
0025def _closeConnection(ref):
0026    conn = ref()
0027    if conn is not None:
0028        conn.close()
0029
0030class ConsoleWriter:
0031    def __init__(self, connection, loglevel):
0032        # loglevel: None or empty string for stdout; or 'stderr'
0033        self.loglevel = loglevel or "stdout"
0034        self.dbEncoding = getattr(connection, "dbEncoding", None) or "ascii"
0035    def write(self, text):
0036        logfile = getattr(sys, self.loglevel)
0037        if isinstance(text, unicode):
0038            try:
0039                text = text.encode(self.dbEncoding)
0040            except UnicodeEncodeError:
0041                text = repr(text)[2:-1] # Remove u'...' from the repr
0042        logfile.write(text + '\n')
0043
0044class LogWriter:
0045    def __init__(self, connection, logger, loglevel):
0046        self.logger = logger
0047        self.loglevel = loglevel
0048        self.logmethod = getattr(logger, loglevel)
0049    def write(self, text):
0050        self.logmethod(text)
0051
0052def makeDebugWriter(connection, loggerName, loglevel):
0053    if not loggerName:
0054        return ConsoleWriter(connection, loglevel)
0055    import logging
0056    logger = logging.getLogger(loggerName)
0057    return LogWriter(connection, logger, loglevel)
0058
0059class Boolean(object):
0060    """A bool class that also understands some special string keywords (yes/no, true/false, on/off, 1/0)"""
0061    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
0062                 '0': False, 'no': False, 'false': False, 'off': False}
0063    def __new__(cls, value):
0064        try:
0065            return Boolean._keywords[value.lower()]
0066        except (AttributeError, KeyError):
0067            return bool(value)
0068
0069class DBConnection:
0070
0071    def __init__(self, name=None, debug=False, debugOutput=False,
0072                 cache=True, style=None, autoCommit=True,
0073                 debugThreading=False, registry=None,
0074                 logger=None, loglevel=None):
0075        self.name = name
0076        self.debug = Boolean(debug)
0077        self.debugOutput = Boolean(debugOutput)
0078        self.debugThreading = Boolean(debugThreading)
0079        self.debugWriter = makeDebugWriter(self, logger, loglevel)
0080        self.doCache = Boolean(cache)
0081        self.cache = CacheSet(cache=self.doCache)
0082        self.style = style
0083        self._connectionNumbers = {}
0084        self._connectionCount = 1
0085        self.autoCommit = Boolean(autoCommit)
0086        self.registry = registry or None
0087        classregistry.registry(self.registry).addCallback(self.soClassAdded)
0088        registerConnectionInstance(self)
0089        atexit.register(_closeConnection, weakref.ref(self))
0090
0091    def oldUri(self):
0092        auth = getattr(self, 'user', '') or ''
0093        if auth:
0094            if self.password:
0095                auth = auth + ':' + self.password
0096            auth = auth + '@'
0097        else:
0098            assert not getattr(self, 'password', None), (
0099                'URIs cannot express passwords without usernames')
0100        uri = '%s://%s' % (self.dbName, auth)
0101        if self.host:
0102            uri += self.host
0103            if self.port:
0104                uri += ':%d' % self.port
0105        uri += '/'
0106        db = self.db
0107        if db.startswith('/'):
0108            db = db[1:]
0109        return uri + db
0110
0111    def uri(self):
0112        auth = getattr(self, 'user', '') or ''
0113        if auth:
0114            auth = urllib.quote(auth)
0115            if self.password:
0116                auth = auth + ':' + urllib.quote(self.password)
0117            auth = auth + '@'
0118        else:
0119            assert not getattr(self, 'password', None), (
0120                'URIs cannot express passwords without usernames')
0121        uri = '%s://%s' % (self.dbName, auth)
0122        if self.host:
0123            uri += self.host
0124            if self.port:
0125                uri += ':%d' % self.port
0126        uri += '/'
0127        db = self.db
0128        if db.startswith('/'):
0129            db = db[1:]
0130        return uri + urllib.quote(db)
0131
0132    @classmethod
0133    def connectionFromOldURI(cls, uri):
0134        return cls._connectionFromParams(*cls._parseOldURI(uri))
0135
0136    @classmethod
0137    def connectionFromURI(cls, uri):
0138        return cls._connectionFromParams(*cls._parseURI(uri))
0139
0140    @staticmethod
0141    def _parseOldURI(uri):
0142        schema, rest = uri.split(':', 1)
0143        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
0144        if rest.startswith('/') and not rest.startswith('//'):
0145            host = None
0146            rest = rest[1:]
0147        elif rest.startswith('///'):
0148            host = None
0149            rest = rest[3:]
0150        else:
0151            rest = rest[2:]
0152            if rest.find('/') == -1:
0153                host = rest
0154                rest = ''
0155            else:
0156                host, rest = rest.split('/', 1)
0157        if host and host.find('@') != -1:
0158            user, host = host.rsplit('@', 1)
0159            if user.find(':') != -1:
0160                user, password = user.split(':', 1)
0161            else:
0162                password = None
0163        else:
0164            user = password = None
0165        if host and host.find(':') != -1:
0166            _host, port = host.split(':')
0167            try:
0168                port = int(port)
0169            except ValueError:
0170                raise ValueError("port must be integer, got '%s' instead" % port)
0171            if not (1 <= port <= 65535):
0172                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
0173            host = _host
0174        else:
0175            port = None
0176        path = '/' + rest
0177        if os.name == 'nt':
0178            if (len(rest) > 1) and (rest[1] == '|'):
0179                path = "%s:%s" % (rest[0], rest[2:])
0180        args = {}
0181        if path.find('?') != -1:
0182            path, arglist = path.split('?', 1)
0183            arglist = arglist.split('&')
0184            for single in arglist:
0185                argname, argvalue = single.split('=', 1)
0186                argvalue = urllib.unquote(argvalue)
0187                args[argname] = argvalue
0188        return user, password, host, port, path, args
0189
0190    @staticmethod
0191    def _parseURI(uri):
0192        protocol, request = urllib.splittype(uri)
0193        user, password, port = None, None, None
0194        host, path = urllib.splithost(request)
0195
0196        if host:
0197            # Python < 2.7 have a problem - splituser() calls unquote() too early
0198            #user, host = urllib.splituser(host)
0199            if '@' in host:
0200                user, host = host.split('@', 1)
0201            if user:
0202                user, password = [x and urllib.unquote(x) or None for x in urllib.splitpasswd(user)]
0203            host, port = urllib.splitport(host)
0204            if port: port = int(port)
0205        elif host == '':
0206            host = None
0207
0208        # hash-tag is splitted but ignored
0209        path, tag = urllib.splittag(path)
0210        path, query = urllib.splitquery(path)
0211
0212        path = urllib.unquote(path)
0213        if (os.name == 'nt') and (len(path) > 2):
0214            # Preserve backward compatibility with URIs like /C|/path;
0215            # replace '|' by ':'
0216            if path[2] == '|':
0217                path = "%s:%s" % (path[0:2], path[3:])
0218            # Remove leading slash
0219            if (path[0] == '/') and (path[2] == ':'):
0220                path = path[1:]
0221
0222        args = {}
0223        if query:
0224            for name, value in parse_qsl(query):
0225                args[name] = value
0226
0227        return user, password, host, port, path, args
0228
0229    def soClassAdded(self, soClass):
0230        """
0231        This is called for each new class; we use this opportunity
0232        to create an instance method that is bound to the class
0233        and this connection.
0234        """
0235        name = soClass.__name__
0236        assert not hasattr(self, name), (
0237            "Connection %r already has an attribute with the name "
0238            "%r (and you just created the conflicting class %r)"
0239            % (self, name, soClass))
0240        setattr(self, name, ConnWrapper(soClass, self))
0241
0242    def expireAll(self):
0243        """
0244        Expire all instances of objects for this connection.
0245        """
0246        cache_set = self.cache
0247        cache_set.weakrefAll()
0248        for item in cache_set.getAll():
0249            item.expire()
0250
0251class ConnWrapper(object):
0252
0253    """
0254    This represents a SQLObject class that is bound to a specific
0255    connection (instances have a connection instance variable, but
0256    classes are global, so this is binds the connection variable
0257    lazily when a class method is accessed)
0258    """
0259    # @@: methods that take connection arguments should be explicitly
0260    # marked up instead of the implicit use of a connection argument
0261    # and inspect.getargspec()
0262
0263    def __init__(self, soClass, connection):
0264        self._soClass = soClass
0265        self._connection = connection
0266
0267    def __call__(self, *args, **kw):
0268        kw['connection'] = self._connection
0269        return self._soClass(*args, **kw)
0270
0271    def __getattr__(self, attr):
0272        meth = getattr(self._soClass, attr)
0273        if not isinstance(meth, types.MethodType):
0274            # We don't need to wrap non-methods
0275            return meth
0276        try:
0277            takes_conn = meth.takes_connection
0278        except AttributeError:
0279            args, varargs, varkw, defaults = inspect.getargspec(meth)
0280            assert not varkw and not varargs, (
0281                "I cannot tell whether I must wrap this method, "
0282                "because it takes **kw: %r"
0283                % meth)
0284            takes_conn = 'connection' in args
0285            meth.im_func.takes_connection = takes_conn
0286        if not takes_conn:
0287            return meth
0288        return ConnMethodWrapper(meth, self._connection)
0289
0290class ConnMethodWrapper(object):
0291
0292    def __init__(self, method, connection):
0293        self._method = method
0294        self._connection = connection
0295
0296    def __getattr__(self, attr):
0297        return getattr(self._method, attr)
0298
0299    def __call__(self, *args, **kw):
0300        kw['connection'] = self._connection
0301        return self._method(*args, **kw)
0302
0303    def __repr__(self):
0304        return '<Wrapped %r with connection %r>' % (
0305            self._method, self._connection)
0306
0307class DBAPI(DBConnection):
0308
0309    """
0310    Subclass must define a `makeConnection()` method, which
0311    returns a newly-created connection object.
0312
0313    ``queryInsertID`` must also be defined.
0314    """
0315
0316    dbName = None
0317
0318    def __init__(self, **kw):
0319        self._pool = []
0320        self._poolLock = threading.Lock()
0321        DBConnection.__init__(self, **kw)
0322        self._binaryType = type(self.module.Binary(''))
0323
0324    def _runWithConnection(self, meth, *args):
0325        conn = self.getConnection()
0326        try:
0327            val = meth(conn, *args)
0328        finally:
0329            self.releaseConnection(conn)
0330        return val
0331
0332    def getConnection(self):
0333        self._poolLock.acquire()
0334        try:
0335            if not self._pool:
0336                conn = self.makeConnection()
0337                self._connectionNumbers[id(conn)] = self._connectionCount
0338                self._connectionCount += 1
0339            else:
0340                conn = self._pool.pop()
0341            if self.debug:
0342                s = 'ACQUIRE'
0343                if self._pool is not None:
0344                    s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0345                self.printDebug(conn, s, 'Pool')
0346            return conn
0347        finally:
0348            self._poolLock.release()
0349
0350    def releaseConnection(self, conn, explicit=False):
0351        if self.debug:
0352            if explicit:
0353                s = 'RELEASE (explicit)'
0354            else:
0355                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
0356            if self._pool is None:
0357                s += ' no pooling'
0358            else:
0359                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0360            self.printDebug(conn, s, 'Pool')
0361        if self.supportTransactions and not explicit:
0362            if self.autoCommit == 'exception':
0363                if self.debug:
0364                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
0365                conn.rollback()
0366                raise Exception('Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed')
0367            elif self.autoCommit:
0368                if self.debug:
0369                    self.printDebug(conn, 'auto', 'COMMIT')
0370                if not getattr(conn, 'autocommit', False):
0371                    conn.commit()
0372            else:
0373                if self.debug:
0374                    self.printDebug(conn, 'auto', 'ROLLBACK')
0375                conn.rollback()
0376        if self._pool is not None:
0377            if conn not in self._pool:
0378                # @@: We can get duplicate releasing of connections with
0379                # the __del__ in Iteration (unfortunately, not sure why
0380                # it happens)
0381                self._pool.insert(0, conn)
0382        else:
0383            conn.close()
0384
0385    def printDebug(self, conn, s, name, type='query'):
0386        if name == 'Pool' and self.debug != 'Pool':
0387            return
0388        if type == 'query':
0389            sep = ': '
0390        else:
0391            sep = '->'
0392            s = repr(s)
0393        n = self._connectionNumbers[id(conn)]
0394        spaces = ' '*(8-len(name))
0395        if self.debugThreading:
0396            threadName = threading.currentThread().getName()
0397            threadName = (':' + threadName + ' '*(8-len(threadName)))
0398        else:
0399            threadName = ''
0400        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
0401        self.debugWriter.write(msg)
0402
0403    def _executeRetry(self, conn, cursor, query):
0404        if self.debug:
0405            self.printDebug(conn, query, 'QueryR')
0406        return cursor.execute(query)
0407
0408    def _query(self, conn, s):
0409        if self.debug:
0410            self.printDebug(conn, s, 'Query')
0411        self._executeRetry(conn, conn.cursor(), s)
0412
0413    def query(self, s):
0414        return self._runWithConnection(self._query, s)
0415
0416    def _queryAll(self, conn, s):
0417        if self.debug:
0418            self.printDebug(conn, s, 'QueryAll')
0419        c = conn.cursor()
0420        self._executeRetry(conn, c, s)
0421        value = c.fetchall()
0422        if self.debugOutput:
0423            self.printDebug(conn, value, 'QueryAll', 'result')
0424        return value
0425
0426    def queryAll(self, s):
0427        return self._runWithConnection(self._queryAll, s)
0428
0429    def _queryAllDescription(self, conn, s):
0430        """
0431        Like queryAll, but returns (description, rows), where the
0432        description is cursor.description (which gives row types)
0433        """
0434        if self.debug:
0435            self.printDebug(conn, s, 'QueryAllDesc')
0436        c = conn.cursor()
0437        self._executeRetry(conn, c, s)
0438        value = c.fetchall()
0439        if self.debugOutput:
0440            self.printDebug(conn, value, 'QueryAll', 'result')
0441        return c.description, value
0442
0443    def queryAllDescription(self, s):
0444        return self._runWithConnection(self._queryAllDescription, s)
0445
0446    def _queryOne(self, conn, s):
0447        if self.debug:
0448            self.printDebug(conn, s, 'QueryOne')
0449        c = conn.cursor()
0450        self._executeRetry(conn, c, s)
0451        value = c.fetchone()
0452        if self.debugOutput:
0453            self.printDebug(conn, value, 'QueryOne', 'result')
0454        return value
0455
0456    def queryOne(self, s):
0457        return self._runWithConnection(self._queryOne, s)
0458
0459    def _insertSQL(self, table, names, values):
0460        return ("INSERT INTO %s (%s) VALUES (%s)" %
0461                (table, ', '.join(names),
0462                 ', '.join([self.sqlrepr(v) for v in values])))
0463
0464    def transaction(self):
0465        return Transaction(self)
0466
0467    def queryInsertID(self, soInstance, id, names, values):
0468        return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)
0469
0470    def iterSelect(self, select):
0471        return select.IterationClass(self, self.getConnection(),
0472                         select, keepConnection=False)
0473
0474    def accumulateSelect(self, select, *expressions):
0475        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
0476            to the select object.
0477        """
0478        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
0479        q = self.sqlrepr(q)
0480        val = self.queryOne(q)
0481        if len(expressions) == 1:
0482            val = val[0]
0483        return val
0484
0485    def queryForSelect(self, select):
0486        return self.sqlrepr(select.queryForSelect())
0487
0488    def _SO_createJoinTable(self, join):
0489        self.query(self._SO_createJoinTableSQL(join))
0490
0491    def _SO_createJoinTableSQL(self, join):
0492        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
0493                (join.intermediateTable,
0494                 join.joinColumn,
0495                 self.joinSQLType(join),
0496                 join.otherColumn,
0497                 self.joinSQLType(join)))
0498
0499    def _SO_dropJoinTable(self, join):
0500        self.query("DROP TABLE %s" % join.intermediateTable)
0501
0502    def _SO_createIndex(self, soClass, index):
0503        self.query(self.createIndexSQL(soClass, index))
0504
0505    def createIndexSQL(self, soClass, index):
0506        assert 0, 'Implement in subclasses'
0507
0508    def createTable(self, soClass):
0509        createSql, constraints = self.createTableSQL(soClass)
0510        self.query(createSql)
0511
0512        return constraints
0513
0514    def createReferenceConstraints(self, soClass):
0515        refConstraints = [self.createReferenceConstraint(soClass, column)                             for column in soClass.sqlmeta.columnList                             if isinstance(column, col.SOForeignKey)]
0518        refConstraintDefs = [constraint                                for constraint in refConstraints                                if constraint]
0521        return refConstraintDefs
0522
0523    def createSQL(self, soClass):
0524        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
0525        if tableCreateSQLs:
0526            assert isinstance(tableCreateSQLs,(str,list,dict,tuple)), (
0527                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
0528                (soClass.__name__))
0529            if isinstance(tableCreateSQLs, dict):
0530                tableCreateSQLs = tableCreateSQLs.get(soClass._connection.dbName, [])
0531            if isinstance(tableCreateSQLs, str):
0532                tableCreateSQLs = [tableCreateSQLs]
0533            if isinstance(tableCreateSQLs, tuple):
0534                tableCreateSQLs = list(tableCreateSQLs)
0535            assert isinstance(tableCreateSQLs,list), (
0536                'Unable to create a list from %s.sqlmeta.createSQL' %
0537                (soClass.__name__))
0538        return tableCreateSQLs or []
0539
0540    def createTableSQL(self, soClass):
0541        constraints = self.createReferenceConstraints(soClass)
0542        extraSQL = self.createSQL(soClass)
0543        createSql = ('CREATE TABLE %s (\n%s\n)' %
0544                (soClass.sqlmeta.table, self.createColumns(soClass)))
0545        return createSql, constraints + extraSQL
0546
0547    def createColumns(self, soClass):
0548        columnDefs = [self.createIDColumn(soClass)]                        + [self.createColumn(soClass, col)
0550                        for col in soClass.sqlmeta.columnList]
0551        return ",\n".join(["    %s" % c for c in columnDefs])
0552
0553    def createReferenceConstraint(self, soClass, col):
0554        assert 0, "Implement in subclasses"
0555
0556    def createColumn(self, soClass, col):
0557        assert 0, "Implement in subclasses"
0558
0559    def dropTable(self, tableName, cascade=False):
0560        self.query("DROP TABLE %s" % tableName)
0561
0562    def clearTable(self, tableName):
0563        # 3-03 @@: Should this have a WHERE 1 = 1 or similar
0564        # clause?  In some configurations without the WHERE clause
0565        # the query won't go through, but maybe we shouldn't override
0566        # that.
0567        self.query("DELETE FROM %s" % tableName)
0568
0569    def createBinary(self, value):
0570        """
0571        Create a binary object wrapper for the given database.
0572        """
0573        # Default is Binary() function from the connection driver.
0574        return self.module.Binary(value)
0575
0576    # The _SO_* series of methods are sorts of "friend" methods
0577    # with SQLObject.  They grab values from the SQLObject instances
0578    # or classes freely, but keep the SQLObject class from accessing
0579    # the database directly.  This way no SQL is actually created
0580    # in the SQLObject class.
0581
0582    def _SO_update(self, so, values):
0583        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
0584                   (so.sqlmeta.table,
0585                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
0586                               for dbName, value in values]),
0587                    so.sqlmeta.idName,
0588                    self.sqlrepr(so.id)))
0589
0590    def _SO_selectOne(self, so, columnNames):
0591        return self._SO_selectOneAlt(so, columnNames, so.q.id==so.id)
0592
0593
0594    def _SO_selectOneAlt(self, so, columnNames, condition):
0595        if columnNames:
0596            columns = [isinstance(x, basestring) and sqlbuilder.SQLConstant(x) or x for x in columnNames]
0597        else:
0598            columns = None
0599        return self.queryOne(self.sqlrepr(sqlbuilder.Select(columns,
0600                                                            staticTables=[so.sqlmeta.table],
0601                                                            clause=condition)))
0602
0603    def _SO_delete(self, so):
0604        self.query("DELETE FROM %s WHERE %s = (%s)" %
0605                   (so.sqlmeta.table,
0606                    so.sqlmeta.idName,
0607                    self.sqlrepr(so.id)))
0608
0609    def _SO_selectJoin(self, soClass, column, value):
0610        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
0611                             (soClass.sqlmeta.idName,
0612                              soClass.sqlmeta.table,
0613                              column,
0614                              self.sqlrepr(value)))
0615
0616    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
0617        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
0618                             (getColumn,
0619                              table,
0620                              joinColumn,
0621                              self.sqlrepr(value)))
0622
0623    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
0624                               secondColumn, secondValue):
0625        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
0626                   (table,
0627                    firstColumn,
0628                    self.sqlrepr(firstValue),
0629                    secondColumn,
0630                    self.sqlrepr(secondValue)))
0631
0632    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
0633                               secondColumn, secondValue):
0634        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
0635                   (table,
0636                    firstColumn,
0637                    secondColumn,
0638                    self.sqlrepr(firstValue),
0639                    self.sqlrepr(secondValue)))
0640
0641    def _SO_columnClause(self, soClass, kw):
0642        ops = {None: "IS"}
0643        data = []
0644        if 'id' in kw:
0645            data.append((soClass.sqlmeta.idName, kw.pop('id')))
0646        for soColumn in soClass.sqlmeta.columnList:
0647            key = soColumn.name
0648            if key in kw:
0649                val = kw.pop(key)
0650                if soColumn.from_python:
0651                    val = soColumn.from_python(val, sqlbuilder.SQLObjectState(soClass, connection=self))
0652                data.append((soColumn.dbName, val))
0653            elif soColumn.foreignName in kw:
0654                obj = kw.pop(soColumn.foreignName)
0655                if isinstance(obj, main.SQLObject):
0656                    data.append((soColumn.dbName, obj.id))
0657                else:
0658                    data.append((soColumn.dbName, obj))
0659        if kw:
0660            # pick the first key from kw to use to raise the error,
0661            raise TypeError("got an unexpected keyword argument(s): %r" % kw.keys())
0662
0663        if not data:
0664            return None
0665        return ' AND '.join(
0666            ['%s %s %s' %
0667             (dbName, ops.get(value, "="), self.sqlrepr(value))
0668             for dbName, value
0669             in data])
0670
0671    def sqlrepr(self, v):
0672        return sqlrepr(v, self.dbName)
0673
0674    def __del__(self):
0675        self.close()
0676
0677    def close(self):
0678        if not hasattr(self, '_pool'):
0679            # Probably there was an exception while creating this
0680            # instance, so it is incomplete.
0681            return
0682        if not self._pool:
0683            return
0684        self._poolLock.acquire()
0685        try:
0686            if not self._pool: # _pool could be filled in a different thread
0687                return
0688            conns = self._pool[:]
0689            self._pool[:] = []
0690            for conn in conns:
0691                try:
0692                    conn.close()
0693                except self.module.Error:
0694                    pass
0695            del conn
0696            del conns
0697        finally:
0698            self._poolLock.release()
0699
0700    def createEmptyDatabase(self):
0701        """
0702        Create an empty database.
0703        """
0704        raise NotImplementedError
0705
0706class Iteration(object):
0707
0708    def __init__(self, dbconn, rawconn, select, keepConnection=False):
0709        self.dbconn = dbconn
0710        self.rawconn = rawconn
0711        self.select = select
0712        self.keepConnection = keepConnection
0713        self.cursor = rawconn.cursor()
0714        self.query = self.dbconn.queryForSelect(select)
0715        if dbconn.debug:
0716            dbconn.printDebug(rawconn, self.query, 'Select')
0717        self.dbconn._executeRetry(self.rawconn, self.cursor, self.query)
0718
0719    def __iter__(self):
0720        return self
0721
0722    def next(self):
0723        result = self.cursor.fetchone()
0724        if result is None:
0725            self._cleanup()
0726            raise StopIteration
0727        if result[0] is None:
0728            return None
0729        if self.select.ops.get('lazyColumns', 0):
0730            obj = self.select.sourceClass.get(result[0], connection=self.dbconn)
0731            return obj
0732        else:
0733            obj = self.select.sourceClass.get(result[0], selectResults=result[1:], connection=self.dbconn)
0734            return obj
0735
0736    def _cleanup(self):
0737        if getattr(self, 'query', None) is None:
0738            # already cleaned up
0739            return
0740        self.query = None
0741        if not self.keepConnection:
0742            self.dbconn.releaseConnection(self.rawconn)
0743        self.dbconn = self.rawconn = self.select = self.cursor = None
0744
0745    def __del__(self):
0746        self._cleanup()
0747
0748class Transaction(object):
0749
0750    def __init__(self, dbConnection):
0751        # this is to skip __del__ in case of an exception in this __init__
0752        self._obsolete = True
0753        self._dbConnection = dbConnection
0754        self._connection = dbConnection.getConnection()
0755        self._dbConnection._setAutoCommit(self._connection, 0)
0756        self.cache = CacheSet(cache=dbConnection.doCache)
0757        self._deletedCache = {}
0758        self._obsolete = False
0759
0760    def assertActive(self):
0761        assert not self._obsolete, "This transaction has already gone through ROLLBACK; begin another transaction"
0762
0763    def query(self, s):
0764        self.assertActive()
0765        return self._dbConnection._query(self._connection, s)
0766
0767    def queryAll(self, s):
0768        self.assertActive()
0769        return self._dbConnection._queryAll(self._connection, s)
0770
0771    def queryOne(self, s):
0772        self.assertActive()
0773        return self._dbConnection._queryOne(self._connection, s)
0774
0775    def queryInsertID(self, soInstance, id, names, values):
0776        self.assertActive()
0777        return self._dbConnection._queryInsertID(
0778            self._connection, soInstance, id, names, values)
0779
0780    def iterSelect(self, select):
0781        self.assertActive()
0782        # We can't keep the cursor open with results in a transaction,
0783        # because we might want to use the connection while we're
0784        # still iterating through the results.
0785        # @@: But would it be okay for psycopg, with threadsafety
0786        # level 2?
0787        return iter(list(select.IterationClass(self, self._connection,
0788                                   select, keepConnection=True)))
0789
0790    def _SO_delete(self, inst):
0791        cls = inst.__class__.__name__
0792        if not cls in self._deletedCache:
0793            self._deletedCache[cls] = []
0794        self._deletedCache[cls].append(inst.id)
0795        meth = new.instancemethod(self._dbConnection._SO_delete.im_func, self, self.__class__)
0796        return meth(inst)
0797
0798    def commit(self, close=False):
0799        if self._obsolete:
0800            # @@: is it okay to get extraneous commits?
0801            return
0802        if self._dbConnection.debug:
0803            self._dbConnection.printDebug(self._connection, '', 'COMMIT')
0804        self._connection.commit()
0805        subCaches = [(sub[0], sub[1].allIDs()) for sub in self.cache.allSubCachesByClassNames().items()]
0806        subCaches.extend([(x[0], x[1]) for x in self._deletedCache.items()])
0807        for cls, ids in subCaches:
0808            for id in ids:
0809                inst = self._dbConnection.cache.tryGetByName(id, cls)
0810                if inst is not None:
0811                    inst.expire()
0812        if close:
0813            self._makeObsolete()
0814
0815    def rollback(self):
0816        if self._obsolete:
0817            # @@: is it okay to get extraneous rollbacks?
0818            return
0819        if self._dbConnection.debug:
0820            self._dbConnection.printDebug(self._connection, '', 'ROLLBACK')
0821        subCaches = [(sub, sub.allIDs()) for sub in self.cache.allSubCaches()]
0822        self._connection.rollback()
0823
0824        for subCache, ids in subCaches:
0825            for id in ids:
0826                inst = subCache.tryGet(id)
0827                if inst is not None:
0828                    inst.expire()
0829        self._makeObsolete()
0830
0831    def __getattr__(self, attr):
0832        """
0833        If nothing else works, let the parent connection handle it.
0834        Except with this transaction as 'self'.  Poor man's
0835        acquisition?  Bad programming?  Okay, maybe.
0836        """
0837        self.assertActive()
0838        attr = getattr(self._dbConnection, attr)
0839        try:
0840            func = attr.im_func
0841        except AttributeError:
0842            if isinstance(attr, ConnWrapper):
0843                return ConnWrapper(attr._soClass, self)
0844            else:
0845                return attr
0846        else:
0847            meth = new.instancemethod(func, self, self.__class__)
0848            return meth
0849
0850    def _makeObsolete(self):
0851        self._obsolete = True
0852        if self._dbConnection.autoCommit:
0853            self._dbConnection._setAutoCommit(self._connection, 1)
0854        self._dbConnection.releaseConnection(self._connection,
0855                                             explicit=True)
0856        self._connection = None
0857        self._deletedCache = {}
0858
0859    def begin(self):
0860        # @@: Should we do this, or should begin() be a no-op when we're
0861        # not already obsolete?
0862        assert self._obsolete, "You cannot begin a new transaction session without rolling back this one"
0863        self._obsolete = False
0864        self._connection = self._dbConnection.getConnection()
0865        self._dbConnection._setAutoCommit(self._connection, 0)
0866
0867    def __del__(self):
0868        if self._obsolete:
0869            return
0870        self.rollback()
0871
0872    def close(self):
0873        raise TypeError('You cannot just close transaction - you should either call rollback(), commit() or commit(close=True) to close the underlying connection.')
0874
0875class ConnectionHub(object):
0876
0877    """
0878    This object serves as a hub for connections, so that you can pass
0879    in a ConnectionHub to a SQLObject subclass as though it was a
0880    connection, but actually bind a real database connection later.
0881    You can also bind connections on a per-thread basis.
0882
0883    You must hang onto the original ConnectionHub instance, as you
0884    cannot retrieve it again from the class or instance.
0885
0886    To use the hub, do something like::
0887
0888        hub = ConnectionHub()
0889        class MyClass(SQLObject):
0890            _connection = hub
0891
0892        hub.threadConnection = connectionFromURI('...')
0893
0894    """
0895
0896    def __init__(self):
0897        self.threadingLocal = threading_local()
0898
0899    def __get__(self, obj, type=None):
0900        # I'm a little surprised we have to do this, but apparently
0901        # the object's private dictionary of attributes doesn't
0902        # override this descriptor.
0903        if (obj is not None) and '_connection' in obj.__dict__:
0904            return obj.__dict__['_connection']
0905        return self.getConnection()
0906
0907    def __set__(self, obj, value):
0908        obj.__dict__['_connection'] = value
0909
0910    def getConnection(self):
0911        try:
0912            return self.threadingLocal.connection
0913        except AttributeError:
0914            try:
0915                return self.processConnection
0916            except AttributeError:
0917                raise AttributeError(
0918                    "No connection has been defined for this thread "
0919                    "or process")
0920
0921    def doInTransaction(self, func, *args, **kw):
0922        """
0923        This routine can be used to run a function in a transaction,
0924        rolling the transaction back if any exception is raised from
0925        that function, and committing otherwise.
0926
0927        Use like::
0928
0929            sqlhub.doInTransaction(process_request, os.environ)
0930
0931        This will run ``process_request(os.environ)``.  The return
0932        value will be preserved.
0933        """
0934        # @@: In Python 2.5, something usable with with: should also
0935        # be added.
0936        try:
0937            old_conn = self.threadingLocal.connection
0938            old_conn_is_threading = True
0939        except AttributeError:
0940            old_conn = self.processConnection
0941            old_conn_is_threading = False
0942        conn = old_conn.transaction()
0943        if old_conn_is_threading:
0944            self.threadConnection = conn
0945        else:
0946            self.processConnection = conn
0947        try:
0948            try:
0949                value = func(*args, **kw)
0950            except:
0951                conn.rollback()
0952                raise
0953            else:
0954                conn.commit(close=True)
0955                return value
0956        finally:
0957            if old_conn_is_threading:
0958                self.threadConnection = old_conn
0959            else:
0960                self.processConnection = old_conn
0961
0962    def _set_threadConnection(self, value):
0963        self.threadingLocal.connection = value
0964
0965    def _get_threadConnection(self):
0966        return self.threadingLocal.connection
0967
0968    def _del_threadConnection(self):
0969        del self.threadingLocal.connection
0970
0971    threadConnection = property(_get_threadConnection,
0972                                _set_threadConnection,
0973                                _del_threadConnection)
0974
0975class ConnectionURIOpener(object):
0976
0977    def __init__(self):
0978        self.schemeBuilders = {}
0979        self.instanceNames = {}
0980        self.cachedURIs = {}
0981
0982    def registerConnection(self, schemes, builder):
0983        for uriScheme in schemes:
0984            assert not uriScheme in self.schemeBuilders                      or self.schemeBuilders[uriScheme] is builder,                      "A driver has already been registered for the URI scheme %s" % uriScheme
0987            self.schemeBuilders[uriScheme] = builder
0988
0989    def registerConnectionInstance(self, inst):
0990        if inst.name:
0991            assert not inst.name in self.instanceNames                      or self.instanceNames[inst.name] is cls,                      "A instance has already been registered with the name %s" % inst.name
0994            assert inst.name.find(':') == -1, "You cannot include ':' in your class names (%r)" % cls.name
0995            self.instanceNames[inst.name] = inst
0996
0997    def connectionForURI(self, uri, oldUri=False, **args):
0998        if args:
0999            if '?' not in uri:
1000                uri += '?' + urllib.urlencode(args)
1001            else:
1002                uri += '&' + urllib.urlencode(args)
1003        if uri in self.cachedURIs:
1004            return self.cachedURIs[uri]
1005        if uri.find(':') != -1:
1006            scheme, rest = uri.split(':', 1)
1007            connCls = self.dbConnectionForScheme(scheme)
1008            if oldUri:
1009                conn = connCls.connectionFromOldURI(uri)
1010            else:
1011                conn = connCls.connectionFromURI(uri)
1012        else:
1013            # We just have a name, not a URI
1014            assert uri in self.instanceNames,                      "No SQLObject driver exists under the name %s" % uri
1016            conn = self.instanceNames[uri]
1017        # @@: Do we care if we clobber another connection?
1018        self.cachedURIs[uri] = conn
1019        return conn
1020
1021    def dbConnectionForScheme(self, scheme):
1022        assert scheme in self.schemeBuilders, (
1023               "No SQLObject driver exists for %s (only %s)"
1024               % (scheme, ', '.join(self.schemeBuilders.keys())))
1025        return self.schemeBuilders[scheme]()
1026
1027TheURIOpener = ConnectionURIOpener()
1028
1029registerConnection = TheURIOpener.registerConnection
1030registerConnectionInstance = TheURIOpener.registerConnectionInstance
1031connectionForURI = TheURIOpener.connectionForURI
1032dbConnectionForScheme = TheURIOpener.dbConnectionForScheme
1033
1034# Register DB URI schemas
1035import firebird
1036import maxdb
1037import mssql
1038import mysql
1039import postgres
1040import rdbhost
1041import sqlite
1042import sybase