0001import threading
0002from util.threadinglocal import local as threading_local
0003import sys
0004import re
0005import warnings
0006import atexit
0007import os
0008import new
0009import types
0010import urllib
0011import weakref
0012import inspect
0013import sqlbuilder
0014from cache import CacheSet
0015import col
0016import main
0017from joins import sorter
0018from converters import sqlrepr
0019import classregistry
0020
0021warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")
0022
0023_connections = {}
0024
0025def _closeConnection(ref):
0026    conn = ref()
0027    if conn is not None:
0028        conn.close()
0029
0030class ConsoleWriter:
0031    def __init__(self, loglevel):
0032        self.loglevel = loglevel
0033        self.logfile = getattr(sys, loglevel or "stdout")
0034    def write(self, text):
0035        self.logfile.write(text + '\n')
0036
0037class LogWriter:
0038    def __init__(self, logger, loglevel):
0039        self.logger = logger
0040        self.loglevel = loglevel
0041        self.logmethod = getattr(logger, loglevel)
0042    def write(self, text):
0043        self.logmethod(text)
0044
0045def makeDebugWriter(loggerName, loglevel):
0046    if not loggerName:
0047        return ConsoleWriter(loglevel)
0048    import logging
0049    logger = logging.getLogger(loggerName)
0050    return LogWriter(logger, loglevel)
0051
0052class DBConnection:
0053
0054    def __init__(self, name=None, debug=False, debugOutput=False,
0055                 cache=True, style=None, autoCommit=True,
0056                 debugThreading=False, registry=None,
0057                 logger=None, loglevel=None):
0058        self.name = name
0059        self.debug = debug
0060        self.debugOutput = debugOutput
0061        self.debugThreading = debugThreading
0062        self.debugWriter = makeDebugWriter(logger, loglevel)
0063        self.cache = CacheSet(cache=cache)
0064        self.doCache = cache
0065        self.style = style
0066        self._connectionNumbers = {}
0067        self._connectionCount = 1
0068        self.autoCommit = autoCommit
0069        self.registry = registry or None
0070        classregistry.registry(self.registry).addCallback(
0071            self.soClassAdded)
0072        registerConnectionInstance(self)
0073        atexit.register(_closeConnection, weakref.ref(self))
0074
0075    def uri(self):
0076        auth = getattr(self, 'user', None) or ''
0077        if auth:
0078            if self.password:
0079                auth = auth + ':' + self.password
0080            auth = auth + '@'
0081        else:
0082            assert not getattr(self, 'password', None), (
0083                'URIs cannot express passwords without usernames')
0084        uri = '%s://%s' % (self.dbName, auth)
0085        if self.host:
0086            uri += self.host
0087        uri += '/'
0088        db = self.db
0089        if db.startswith('/'):
0090            db = path[1:]
0091        return uri + db
0092
0093    def isSupported(cls):
0094        raise NotImplemented
0095    isSupported = classmethod(isSupported)
0096
0097    def connectionFromURI(cls, uri):
0098        raise NotImplemented
0099    connectionFromURI = classmethod(connectionFromURI)
0100
0101    def _parseURI(uri):
0102        schema, rest = uri.split(':', 1)
0103        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
0104        if rest.startswith('/') and not rest.startswith('//'):
0105            host = None
0106            rest = rest[1:]
0107        elif rest.startswith('///'):
0108            host = None
0109            rest = rest[3:]
0110        else:
0111            rest = rest[2:]
0112            if rest.find('/') == -1:
0113                host = rest
0114                rest = ''
0115            else:
0116                host, rest = rest.split('/', 1)
0117        if host and host.find('@') != -1:
0118            user = host[:host.rfind('@')] # Python 2.3 doesn't have .rsplit()
0119            host = host[host.rfind('@')+1:] # !!!
0120            if user.find(':') != -1:
0121                user, password = user.split(':', 1)
0122            else:
0123                password = None
0124        else:
0125            user = password = None
0126        if host and host.find(':') != -1:
0127            _host, port = host.split(':')
0128            try:
0129                port = int(port)
0130            except ValueError:
0131                raise ValueError, "port must be integer, got '%s' instead" % port
0132            if not (1 <= port <= 65535):
0133                raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
0134            host = _host
0135        else:
0136            port = None
0137        path = '/' + rest
0138        if os.name == 'nt':
0139            if (len(rest) > 1) and (rest[1] == '|'):
0140                path = "%s:%s" % (rest[0], rest[2:])
0141        args = {}
0142        if path.find('?') != -1:
0143            path, arglist = path.split('?', 1)
0144            arglist = arglist.split('&')
0145            for single in arglist:
0146                argname, argvalue = single.split('=', 1)
0147                argvalue = urllib.unquote(argvalue)
0148                args[argname] = argvalue
0149        return user, password, host, port, path, args
0150    _parseURI = staticmethod(_parseURI)
0151
0152    def soClassAdded(self, soClass):
0153        """
0154        This is called for each new class; we use this opportunity
0155        to create an instance method that is bound to the class
0156        and this connection.
0157        """
0158        name = soClass.__name__
0159        assert not hasattr(self, name), (
0160            "Connection %r already has an attribute with the name "
0161            "%r (and you just created the conflicting class %r)"
0162            % (self, name, soClass))
0163        setattr(self, name, ConnWrapper(soClass, self))
0164
0165    def expireAll(self):
0166        """
0167        Expire all instances of objects for this connection.
0168        """
0169        cache_set = self.cache
0170        cache_set.weakrefAll()
0171        for item in cache_set.getAll():
0172            item.expire()
0173
0174class ConnWrapper(object):
0175
0176    """
0177    This represents a SQLObject class that is bound to a specific
0178    connection (instances have a connection instance variable, but
0179    classes are global, so this is binds the connection variable
0180    lazily when a class method is accessed)
0181    """
0182    # @@: methods that take connection arguments should be explicitly
0183    # marked up instead of the implicit use of a connection argument
0184    # and inspect.getargspec()
0185
0186    def __init__(self, soClass, connection):
0187        self._soClass = soClass
0188        self._connection = connection
0189
0190    def __call__(self, *args, **kw):
0191        kw['connection'] = self._connection
0192        return self._soClass(*args, **kw)
0193
0194    def __getattr__(self, attr):
0195        meth = getattr(self._soClass, attr)
0196        if not isinstance(meth, types.MethodType):
0197            # We don't need to wrap non-methods
0198            return meth
0199        try:
0200            takes_conn = meth.takes_connection
0201        except AttributeError:
0202            args, varargs, varkw, defaults = inspect.getargspec(meth)
0203            assert not varkw and not varargs, (
0204                "I cannot tell whether I must wrap this method, "
0205                "because it takes **kw: %r"
0206                % meth)
0207            takes_conn = 'connection' in args
0208            meth.im_func.takes_connection = takes_conn
0209        if not takes_conn:
0210            return meth
0211        return ConnMethodWrapper(meth, self._connection)
0212
0213class ConnMethodWrapper(object):
0214
0215    def __init__(self, method, connection):
0216        self._method = method
0217        self._connection = connection
0218
0219    def __getattr__(self, attr):
0220        return getattr(self._method, attr)
0221
0222    def __call__(self, *args, **kw):
0223        kw['connection'] = self._connection
0224        return self._method(*args, **kw)
0225
0226    def __repr__(self):
0227        return '<Wrapped %r with connection %r>' % (
0228            self._method, self._connection)
0229
0230class DBAPI(DBConnection):
0231
0232    """
0233    Subclass must define a `makeConnection()` method, which
0234    returns a newly-created connection object.
0235
0236    ``queryInsertID`` must also be defined.
0237    """
0238
0239    dbName = None
0240
0241    def __init__(self, **kw):
0242        self._pool = []
0243        self._poolLock = threading.Lock()
0244        DBConnection.__init__(self, **kw)
0245        self._binaryType = type(self.module.Binary(''))
0246
0247    def _runWithConnection(self, meth, *args):
0248        conn = self.getConnection()
0249        try:
0250            val = meth(conn, *args)
0251        finally:
0252            self.releaseConnection(conn)
0253        return val
0254
0255    def getConnection(self):
0256        self._poolLock.acquire()
0257        try:
0258            if not self._pool:
0259                conn = self.makeConnection()
0260                self._connectionNumbers[id(conn)] = self._connectionCount
0261                self._connectionCount += 1
0262            else:
0263                conn = self._pool.pop()
0264            if self.debug:
0265                s = 'ACQUIRE'
0266                if self._pool is not None:
0267                    s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0268                self.printDebug(conn, s, 'Pool')
0269            return conn
0270        finally:
0271            self._poolLock.release()
0272
0273    def releaseConnection(self, conn, explicit=False):
0274        if self.debug:
0275            if explicit:
0276                s = 'RELEASE (explicit)'
0277            else:
0278                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
0279            if self._pool is None:
0280                s += ' no pooling'
0281            else:
0282                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0283            self.printDebug(conn, s, 'Pool')
0284        if self.supportTransactions and not explicit:
0285            if self.autoCommit == 'exception':
0286                if self.debug:
0287                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
0288                conn.rollback()
0289                raise Exception, 'Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed'
0290            elif self.autoCommit:
0291                if self.debug:
0292                    self.printDebug(conn, 'auto', 'COMMIT')
0293                if not getattr(conn, 'autocommit', False):
0294                    conn.commit()
0295            else:
0296                if self.debug:
0297                    self.printDebug(conn, 'auto', 'ROLLBACK')
0298                conn.rollback()
0299        if self._pool is not None:
0300            if conn not in self._pool:
0301                # @@: We can get duplicate releasing of connections with
0302                # the __del__ in Iteration (unfortunately, not sure why
0303                # it happens)
0304                self._pool.insert(0, conn)
0305        else:
0306            conn.close()
0307
0308    def printDebug(self, conn, s, name, type='query'):
0309        if name == 'Pool' and self.debug != 'Pool':
0310            return
0311        if type == 'query':
0312            sep = ': '
0313        else:
0314            sep = '->'
0315            s = repr(s)
0316        n = self._connectionNumbers[id(conn)]
0317        spaces = ' '*(8-len(name))
0318        if self.debugThreading:
0319            threadName = threading.currentThread().getName()
0320            threadName = (':' + threadName + ' '*(8-len(threadName)))
0321        else:
0322            threadName = ''
0323        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
0324        self.debugWriter.write(msg)
0325
0326    def _executeRetry(self, conn, cursor, query):
0327        if self.debug:
0328            self.printDebug(conn, query, 'QueryR')
0329        return cursor.execute(query)
0330
0331    def _query(self, conn, s):
0332        if self.debug:
0333            self.printDebug(conn, s, 'Query')
0334        self._executeRetry(conn, conn.cursor(), s)
0335
0336    def query(self, s):
0337        return self._runWithConnection(self._query, s)
0338
0339    def _queryAll(self, conn, s):
0340        if self.debug:
0341            self.printDebug(conn, s, 'QueryAll')
0342        c = conn.cursor()
0343        self._executeRetry(conn, c, s)
0344        value = c.fetchall()
0345        if self.debugOutput:
0346            self.printDebug(conn, value, 'QueryAll', 'result')
0347        return value
0348
0349    def queryAll(self, s):
0350        return self._runWithConnection(self._queryAll, s)
0351
0352    def _queryAllDescription(self, conn, s):
0353        """
0354        Like queryAll, but returns (description, rows), where the
0355        description is cursor.description (which gives row types)
0356        """
0357        if self.debug:
0358            self.printDebug(conn, s, 'QueryAllDesc')
0359        c = conn.cursor()
0360        self._executeRetry(conn, c, s)
0361        value = c.fetchall()
0362        if self.debugOutput:
0363            self.printDebug(conn, value, 'QueryAll', 'result')
0364        return c.description, value
0365
0366    def queryAllDescription(self, s):
0367        return self._runWithConnection(self._queryAllDescription, s)
0368
0369    def _queryOne(self, conn, s):
0370        if self.debug:
0371            self.printDebug(conn, s, 'QueryOne')
0372        c = conn.cursor()
0373        self._executeRetry(conn, c, s)
0374        value = c.fetchone()
0375        if self.debugOutput:
0376            self.printDebug(conn, value, 'QueryOne', 'result')
0377        return value
0378
0379    def queryOne(self, s):
0380        return self._runWithConnection(self._queryOne, s)
0381
0382    def _insertSQL(self, table, names, values):
0383        return ("INSERT INTO %s (%s) VALUES (%s)" %
0384                (table, ', '.join(names),
0385                 ', '.join([self.sqlrepr(v) for v in values])))
0386
0387    def transaction(self):
0388        return Transaction(self)
0389
0390    def queryInsertID(self, soInstance, id, names, values):
0391        return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)
0392
0393    def iterSelect(self, select):
0394        return select.IterationClass(self, self.getConnection(),
0395                         select, keepConnection=False)
0396
0397    def accumulateSelect(self, select, *expressions):
0398        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
0399            to the select object.
0400        """
0401        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
0402        q = self.sqlrepr(q)
0403        val = self.queryOne(q)
0404        if len(expressions) == 1:
0405            val = val[0]
0406        return val
0407
0408    def queryForSelect(self, select):
0409        return self.sqlrepr(select.queryForSelect())
0410
0411    def _SO_createJoinTable(self, join):
0412        self.query(self._SO_createJoinTableSQL(join))
0413
0414    def _SO_createJoinTableSQL(self, join):
0415        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
0416                (join.intermediateTable,
0417                 join.joinColumn,
0418                 self.joinSQLType(join),
0419                 join.otherColumn,
0420                 self.joinSQLType(join)))
0421
0422    def _SO_dropJoinTable(self, join):
0423        self.query("DROP TABLE %s" % join.intermediateTable)
0424
0425    def _SO_createIndex(self, soClass, index):
0426        self.query(self.createIndexSQL(soClass, index))
0427
0428    def createIndexSQL(self, soClass, index):
0429        assert 0, 'Implement in subclasses'
0430
0431    def createTable(self, soClass):
0432        createSql, constraints = self.createTableSQL(soClass)
0433        self.query(createSql)
0434
0435        return constraints
0436
0437    def createReferenceConstraints(self, soClass):
0438        refConstraints = [self.createReferenceConstraint(soClass, column)                             for column in soClass.sqlmeta.columnList                             if isinstance(column, col