0001import threading
0002from util.threadinglocal import local as threading_local
0003import sys
0004import re
0005import warnings
0006import atexit
0007import os
0008import new
0009import types
0010import urllib
0011import weakref
0012import inspect
0013import sqlbuilder
0014from cache import CacheSet
0015import col
0016import main
0017from joins import sorter
0018from converters import sqlrepr
0019import classregistry
0020
0021warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")
0022
0023_connections = {}
0024
0025def _closeConnection(ref):
0026    conn = ref()
0027    if conn is not None:
0028        conn.close()
0029
0030class ConsoleWriter:
0031    def __init__(self, loglevel):
0032        self.loglevel = loglevel
0033        self.logfile = getattr(sys, loglevel or "stdout")
0034    def write(self, text):
0035        self.logfile.write(text + '\n')
0036
0037class LogWriter:
0038    def __init__(self, logger, loglevel):
0039        self.logger = logger
0040        self.loglevel = loglevel
0041        self.logmethod = getattr(logger, loglevel)
0042    def write(self, text):
0043        self.logmethod(text)
0044
0045def makeDebugWriter(loggerName, loglevel):
0046    if not loggerName:
0047        return ConsoleWriter(loglevel)
0048    import logging
0049    logger = logging.getLogger(loggerName)
0050    return LogWriter(logger, loglevel)
0051
0052class Boolean(object):
0053    """A bool class that also understands some special string keywords (yes/no, true/false, on/off, 1/0)"""
0054    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
0055                 '0': False, 'no': False, 'false': False, 'off': False}
0056    def __new__(cls, value):
0057        try:
0058            return Boolean._keywords[value.lower()]
0059        except (AttributeError, KeyError):
0060            return bool(value)
0061
0062class DBConnection:
0063
0064    def __init__(self, name=None, debug=False, debugOutput=False,
0065                 cache=True, style=None, autoCommit=True,
0066                 debugThreading=False, registry=None,
0067                 logger=None, loglevel=None):
0068        self.name = name
0069        self.debug = Boolean(debug)
0070        self.debugOutput = Boolean(debugOutput)
0071        self.debugThreading = Boolean(debugThreading)
0072        self.debugWriter = makeDebugWriter(logger, loglevel)
0073        self.doCache = Boolean(cache)
0074        self.cache = CacheSet(cache=self.doCache)
0075        self.style = style
0076        self._connectionNumbers = {}
0077        self._connectionCount = 1
0078        self.autoCommit = Boolean(autoCommit)
0079        self.registry = registry or None
0080        classregistry.registry(self.registry).addCallback(
0081            self.soClassAdded)
0082        registerConnectionInstance(self)
0083        atexit.register(_closeConnection, weakref.ref(self))
0084
0085    def uri(self):
0086        auth = getattr(self, 'user', '')
0087        if auth:
0088            if self.password:
0089                auth = auth + ':' + self.password
0090            auth = auth + '@'
0091        else:
0092            assert not getattr(self, 'password', None), (
0093                'URIs cannot express passwords without usernames')
0094        uri = '%s://%s' % (self.dbName, auth)
0095        if self.host:
0096            uri += self.host
0097        uri += '/'
0098        db = self.db
0099        if db.startswith('/'):
0100            db = path[1:]
0101        return uri + db
0102
0103    def isSupported(cls):
0104        raise NotImplemented
0105    isSupported = classmethod(isSupported)
0106
0107    def connectionFromURI(cls, uri):
0108        raise NotImplemented
0109    connectionFromURI = classmethod(connectionFromURI)
0110
0111    def _parseURI(uri):
0112        schema, rest = uri.split(':', 1)
0113        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
0114        if rest.startswith('/') and not rest.startswith('//'):
0115            host = None
0116            rest = rest[1:]
0117        elif rest.startswith('///'):
0118            host = None
0119            rest = rest[3:]
0120        else:
0121            rest = rest[2:]
0122            if rest.find('/') == -1:
0123                host = rest
0124                rest = ''
0125            else:
0126                host, rest = rest.split('/', 1)
0127        if host and host.find('@') != -1:
0128            user, host = host.rsplit('@', 1)
0129            if user.find(':') != -1:
0130                user, password = user.split(':', 1)
0131            else:
0132                password = None
0133        else:
0134            user = password = None
0135        if host and host.find(':') != -1:
0136            _host, port = host.split(':')
0137            try:
0138                port = int(port)
0139            except ValueError:
0140                raise ValueError, "port must be integer, got '%s' instead" % port
0141            if not (1 <= port <= 65535):
0142                raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
0143            host = _host
0144        else:
0145            port = None
0146        path = '/' + rest
0147        if os.name == 'nt':
0148            if (len(rest) > 1) and (rest[1] == '|'):
0149                path = "%s:%s" % (rest[0], rest[2:])
0150        args = {}
0151        if path.find('?') != -1:
0152            path, arglist = path.split('?', 1)
0153            arglist = arglist.split('&')
0154            for single in arglist:
0155                argname, argvalue = single.split('=', 1)
0156                argvalue = urllib.unquote(argvalue)
0157                args[argname] = argvalue
0158        return user, password, host, port, path, args
0159    _parseURI = staticmethod(_parseURI)
0160
0161    def soClassAdded(self, soClass):
0162        """
0163        This is called for each new class; we use this opportunity
0164        to create an instance method that is bound to the class
0165        and this connection.
0166        """
0167        name = soClass.__name__
0168        assert not hasattr(self, name), (
0169            "Connection %r already has an attribute with the name "
0170            "%r (and you just created the conflicting class %r)"
0171            % (self, name, soClass))
0172        setattr(self, name, ConnWrapper(soClass, self))
0173
0174    def expireAll(self):
0175        """
0176        Expire all instances of objects for this connection.
0177        """
0178        cache_set = self.cache
0179        cache_set.weakrefAll()
0180        for item in cache_set.getAll():
0181            item.expire()
0182
0183class ConnWrapper(object):
0184
0185    """
0186    This represents a SQLObject class that is bound to a specific
0187    connection (instances have a connection instance variable, but
0188    classes are global, so this is binds the connection variable
0189    lazily when a class method is accessed)
0190    """
0191    # @@: methods that take connection arguments should be explicitly
0192    # marked up instead of the implicit use of a connection argument
0193    # and inspect.getargspec()
0194
0195    def __init__(self, soClass, connection):
0196        self._soClass = soClass
0197        self._connection = connection
0198
0199    def __call__(self, *args, **kw):
0200        kw['connection'] = self._connection
0201        return self._soClass(*args, **kw)
0202
0203    def __getattr__(self, attr):
0204        meth = getattr(self._soClass, attr)
0205        if not isinstance(meth, types.MethodType):
0206            # We don't need to wrap non-methods
0207            return meth
0208        try:
0209            takes_conn = meth.takes_connection
0210        except AttributeError:
0211            args, varargs, varkw, defaults = inspect.getargspec(meth)
0212            assert not varkw and not varargs, (
0213                "I cannot tell whether I must wrap this method, "
0214                "because it takes **kw: %r"
0215                % meth)
0216            takes_conn = 'connection' in args
0217            meth.im_func.takes_connection = takes_conn
0218        if not takes_conn:
0219            return meth
0220        return ConnMethodWrapper(meth, self._connection)
0221
0222class ConnMethodWrapper(object):
0223
0224    def __init__(self, method, connection):
0225        self._method = method
0226        self._connection = connection
0227
0228    def __getattr__(self, attr):
0229        return getattr(self._method, attr)
0230
0231    def __call__(self, *args, **kw):
0232        kw['connection'] = self._connection
0233        return self._method(*args, **kw)
0234
0235    def __repr__(self):
0236        return '<Wrapped %r with connection %r>' % (
0237            self._method, self._connection)
0238
0239class DBAPI(DBConnection):
0240
0241    """
0242    Subclass must define a `makeConnection()` method, which
0243    returns a newly-created connection object.
0244
0245    ``queryInsertID`` must also be defined.
0246    """
0247
0248    dbName = None
0249
0250    def __init__(self, **kw):
0251        self._pool = []
0252        self._poolLock = threading.Lock()
0253        DBConnection.__init__(self, **kw)
0254        self._binaryType = type(self.module.Binary(''))
0255
0256    def _runWithConnection(self, meth, *args):
0257        conn = self.getConnection()
0258        try:
0259            val = meth(conn, *args)
0260        finally:
0261            self.releaseConnection(conn)
0262        return val
0263
0264    def getConnection(self):
0265        self._poolLock.acquire()
0266        try:
0267            if not self._pool:
0268                conn = self.makeConnection()
0269                self._connectionNumbers[id(conn)] = self._connectionCount
0270                self._connectionCount += 1
0271            else:
0272                conn = self._pool.pop()
0273            if self.debug:
0274                s = 'ACQUIRE'
0275                if self._pool is not None:
0276                    s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0277                self.printDebug(conn, s, 'Pool')
0278            return conn
0279        finally:
0280            self._poolLock.release()
0281
0282    def releaseConnection(self, conn, explicit=False):
0283        if self.debug:
0284            if explicit:
0285                s = 'RELEASE (explicit)'
0286            else:
0287                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
0288            if self._pool is None:
0289                s += ' no pooling'
0290            else:
0291                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0292            self.printDebug(conn, s, 'Pool')
0293        if self.supportTransactions and not explicit:
0294            if self.autoCommit == 'exception':
0295                if self.debug:
0296                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
0297                conn.rollback()
0298                raise Exception, 'Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed'
0299            elif self.autoCommit:
0300                if self.debug:
0301                    self.printDebug(conn, 'auto', 'COMMIT')
0302                if not getattr(conn, 'autocommit', False):
0303                    conn.commit()
0304            else:
0305                if self.debug:
0306                    self.printDebug(conn, 'auto', 'ROLLBACK')
0307                conn.rollback()
0308        if self._pool is not None:
0309            if conn not in self._pool:
0310                # @@: We can get duplicate releasing of connections with
0311                # the __del__ in Iteration (unfortunately, not sure why
0312                # it happens)
0313                self._pool.insert(0, conn)
0314        else:
0315            conn.close()
0316
0317    def printDebug(self, conn, s, name, type='query'):
0318        if name == 'Pool' and self.debug != 'Pool':
0319            return
0320        if type == 'query':
0321            sep = ': '
0322        else:
0323            sep = '->'
0324            s = repr(s)
0325        n = self._connectionNumbers[id(conn)]
0326        spaces = ' '*(8-len(name))
0327        if self.debugThreading:
0328            threadName = threading.currentThread().getName()
0329            threadName = (':' + threadName + ' '*(8-len(threadName)))
0330        else:
0331            threadName = ''
0332        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
0333        self.debugWriter.write(msg)
0334
0335    def _executeRetry(self, conn, cursor, query):
0336        if self.debug:
0337            self.printDebug(conn, query, 'QueryR')
0338        return cursor.execute(query)
0339
0340    def _query(self, conn, s):
0341        if self.debug:
0342            self.printDebug(conn, s, 'Query')
0343        self._executeRetry(conn, conn.cursor(), s)
0344
0345    def query(self, s):
0346        return self._runWithConnection(self._query, s)
0347
0348    def _queryAll(self, conn, s):
0349        if self.debug:
0350            self.printDebug(conn, s, 'QueryAll')
0351        c = conn.cursor()
0352        self._executeRetry(conn, c, s)
0353        value = c.fetchall()
0354        if self.debugOutput:
0355            self.printDebug(conn, value, 'QueryAll', 'result')
0356        return value
0357
0358    def queryAll(self, s):
0359        return self._runWithConnection(self._queryAll, s)
0360
0361    def _queryAllDescription(self, conn, s):
0362        """
0363        Like queryAll, but returns (description, rows), where the
0364        description is cursor.description (which gives row types)
0365        """
0366        if self.debug:
0367            self.printDebug(conn, s, 'QueryAllDesc')
0368        c = conn.cursor()
0369        self._executeRetry(conn, c, s)
0370        value = c.fetchall()
0371        if self.debugOutput:
0372            self.printDebug(conn, value, 'QueryAll', 'result')
0373        return c.description, value
0374
0375    def queryAllDescription(self, s):
0376        return self._runWithConnection(self._queryAllDescription, s)
0377
0378    def _queryOne(self, conn, s):
0379        if self.debug:
0380            self.printDebug(conn, s, 'QueryOne')
0381        c = conn.cursor()
0382        self._executeRetry(conn, c, s)
0383        value = c.fetchone()
0384        if self.debugOutput:
0385            self.printDebug(conn, value, 'QueryOne', 'result')
0386        return value
0387
0388    def queryOne(self, s):
0389        return self._runWithConnection(self._queryOne, s)
0390
0391    def _insertSQL(self, table, names, values):
0392        return ("INSERT INTO %s (%s) VALUES (%s)" %
0393                (table, ', '.join(names),
0394                 ', '.join([self.sqlrepr(v) for v in values])))
0395
0396    def transaction(self):
0397        return Transaction(self)
0398
0399    def queryInsertID(self, soInstance, id, names, values):
0400        return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)
0401
0402    def iterSelect(self, select):
0403        return select.IterationClass(self, self.getConnection(),
0404                         select, keepConnection=False)
0405
0406    def accumulateSelect(self, select, *expressions):
0407        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
0408            to the select object.
0409        """
0410        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
0411        q = self.sqlrepr(q)
0412        val = self.queryOne(q)
0413        if len(expressions) == 1:
0414            val = val[0]
0415        return val
0416
0417    def queryForSelect(self, select):
0418        return self.sqlrepr(select.queryForSelect())
0419
0420    def _SO_createJoinTable(self, join):
0421        self.query(self._SO_createJoinTableSQL(join))
0422
0423    def _SO_createJoinTableSQL(self, join):
0424        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
0425                (join.intermediateTable,
0426                 join.joinColumn,
0427                 self.joinSQLType(join),
0428                 join.otherColumn,
0429                 self.joinSQLType(join)))
0430
0431    def _SO_dropJoinTable(self, join):
0432        self.query("DROP TABLE %s" % join.intermediateTable)
0433
0434    def _SO_createIndex(self, soClass, index):
0435        self.query(self.createIndexSQL(soClass, index))
0436
0437    def createIndexSQL(self, soClass, index):
0438        assert 0, 'Implement in subclasses'
0439