0001import threading
0002from util.threadinglocal import local as threading_local
0003import sys
0004import re
0005import warnings
0006import atexit
0007import os
0008import new
0009import types
0010import urllib
0011import weakref
0012import inspect
0013import sqlbuilder
0014from cache import CacheSet
0015import col
0016import main
0017from joins import sorter
0018from converters import sqlrepr
0019import classregistry
0020
0021warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")
0022
0023_connections = {}
0024
0025def _closeConnection(ref):
0026 conn = ref()
0027 if conn is not None:
0028 conn.close()
0029
class ConsoleWriter:
    """Debug writer that prints each message as one line to a sys stream.

    *loglevel* names the ``sys`` attribute to write to (for example
    ``"stderr"``); a false value selects ``sys.stdout``.
    """
    def __init__(self, loglevel):
        self.loglevel = loglevel
        if loglevel:
            streamName = loglevel
        else:
            streamName = "stdout"
        self.logfile = getattr(sys, streamName)
    def write(self, text):
        line = text + '\n'
        self.logfile.write(line)
0036
class LogWriter:
    """Debug writer that forwards each message to one method of a logger.

    *loglevel* is the name of the logger method to call (for example
    ``"debug"`` or ``"info"``); it is looked up once, at construction.
    """
    def __init__(self, logger, loglevel):
        self.logger, self.loglevel = logger, loglevel
        self.logmethod = getattr(self.logger, self.loglevel)
    def write(self, text):
        emit = self.logmethod
        emit(text)
0044
def makeDebugWriter(loggerName, loglevel):
    """Return the writer used for debug output.

    With a *loggerName*, messages go to ``logging.getLogger(loggerName)``
    via the method named *loglevel*; otherwise they are printed to the
    ``sys`` stream named by *loglevel* (stdout by default).
    """
    if loggerName:
        import logging
        return LogWriter(logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(loglevel)
0051
class Boolean(object):
    """A bool class that also understands some special string keywords (yes/no, true/false, on/off, 1/0)"""
    _keywords = {
        '1': True, 'yes': True, 'true': True, 'on': True,
        '0': False, 'no': False, 'false': False, 'off': False,
    }
    def __new__(cls, value):
        # Constructing a Boolean never yields a Boolean instance: the
        # result is a plain bool.  Recognised keywords (case-insensitive)
        # map explicitly; everything else falls back to truthiness.
        try:
            key = value.lower()
        except AttributeError:
            return bool(value)
        if key in Boolean._keywords:
            return Boolean._keywords[key]
        return bool(value)
0061
class DBConnection:

    """
    Base class for database connections.

    Holds the settings shared by every backend (debug flags, caching,
    naming style, autocommit, class registry) and knows how to parse
    and rebuild connection URIs.  ``uri()`` reads ``dbName``, ``db``,
    ``host``, ``port``, ``user`` and ``password``, which this base class
    never sets -- subclasses are expected to provide them.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        self.name = name
        # 'Pool' is a special debug mode: DBAPI.printDebug only emits pool
        # messages when self.debug == 'Pool'.  Boolean() would turn it into
        # True and make that mode unreachable, so keep the literal string.
        if debug == 'Pool':
            self.debug = debug
        else:
            self.debug = Boolean(debug)
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        self.debugWriter = makeDebugWriter(logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        self.style = style
        self._connectionNumbers = {}
        self._connectionCount = 1
        # 'exception' forbids implicit COMMIT/ROLLBACK (checked by
        # DBAPI.releaseConnection); Boolean() would collapse it to True.
        if autoCommit == 'exception':
            self.autoCommit = autoCommit
        else:
            self.autoCommit = Boolean(autoCommit)
        self.registry = registry or None
        classregistry.registry(self.registry).addCallback(
            self.soClassAdded)
        registerConnectionInstance(self)
        # Weak reference so the atexit hook does not keep us alive.
        atexit.register(_closeConnection, weakref.ref(self))

    def uri(self):
        """
        Rebuild a connection URI from this connection's attributes.

        The result has the form
        ``dbName://[user[:password]@][host[:port]]/db``; a leading ``/``
        on ``db`` is not doubled.
        """
        # `user` may exist but be None; treat that like "no user".
        auth = getattr(self, 'user', '') or ''
        password = getattr(self, 'password', None)
        if auth:
            if password:
                auth = auth + ':' + password
            auth = auth + '@'
        else:
            assert not password, (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        host = getattr(self, 'host', None)
        if host:
            uri += host
            port = getattr(self, 'port', None)
            if port:
                uri += ':%s' % port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + db

    def isSupported(cls):
        """Report whether this backend can be used; subclasses must override."""
        raise NotImplementedError
    isSupported = classmethod(isSupported)

    def connectionFromURI(cls, uri):
        """Build a connection from *uri*; subclasses must override."""
        raise NotImplementedError
    connectionFromURI = classmethod(connectionFromURI)

    def _parseURI(uri):
        """
        Split a connection URI into its components.

        Accepted forms are ``scheme:/path``, ``scheme:///path`` and
        ``scheme://[user[:password]@]host[:port][/path][?name=value&...]``.

        Returns ``(user, password, host, port, path, args)``: missing parts
        are ``None``, ``port`` is an int, ``path`` keeps its leading ``/``
        and ``args`` maps query names to URL-unquoted values.  Empty query
        segments (e.g. a trailing ``&``) are ignored.

        Raises ValueError for a non-integer or out-of-range port, or for a
        query segment without ``=``.
        """
        try:
            from urllib import unquote  # Python 2
        except ImportError:
            from urllib.parse import unquote  # Python 3
        _scheme, rest = uri.split(':', 1)
        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
        if rest.startswith('/') and not rest.startswith('//'):
            # scheme:/path -- no authority part.
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            # scheme:///path -- empty authority.
            host = None
            rest = rest[3:]
        else:
            # scheme://authority[/path]
            rest = rest[2:]
            if '/' in rest:
                host, rest = rest.split('/', 1)
            else:
                host, rest = rest, ''
        if host and '@' in host:
            # rsplit: a password may itself contain '@'.
            user, host = host.rsplit('@', 1)
            if ':' in user:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and ':' in host:
            # rsplit keeps any ':' inside the host part (e.g. "[::1]:5432").
            host, port = host.rsplit(':', 1)
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            # Windows drive letters may be spelled "C|/..." in URIs.
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if '?' in path:
            path, query = path.split('?', 1)
            for single in query.split('&'):
                if not single:
                    continue
                if '=' not in single:
                    raise ValueError(
                        "URI argument must have the form name=value, got %r"
                        % single)
                argname, argvalue = single.split('=', 1)
                args[argname] = unquote(argvalue)
        return user, password, host, port, path, args
    _parseURI = staticmethod(_parseURI)

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()
0182
class ConnWrapper(object):

    """
    A SQLObject class bound to one specific connection.

    Instances carry their own connection attribute, but classes are
    global; this wrapper binds the connection lazily -- whenever the
    class is called, or a method that accepts a ``connection`` argument
    is looked up through it.
    """

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        # Instances created through the wrapper always use our connection.
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        member = getattr(self._soClass, attr)
        if isinstance(member, types.MethodType):
            try:
                wantsConnection = member.takes_connection
            except AttributeError:
                # First lookup of this method: inspect its signature and
                # cache the answer on the underlying function.
                argNames, starArgs, starKw, _defaults = inspect.getargspec(member)
                assert not (starKw or starArgs), (
                    "I cannot tell whether I must wrap this method, "
                    "because it takes **kw: %r"
                    % member)
                wantsConnection = 'connection' in argNames
                member.im_func.takes_connection = wantsConnection
            if wantsConnection:
                return ConnMethodWrapper(member, self._connection)
        # Non-methods, and methods without a connection parameter,
        # are returned untouched.
        return member
0221
class ConnMethodWrapper(object):

    """
    Callable proxy around a method that accepts a ``connection`` keyword.

    Every call forces ``connection`` to the bound connection (overriding
    any value the caller passed); other attribute access is delegated to
    the wrapped method.
    """

    def __init__(self, method, connection):
        self._method = method
        self._connection = connection

    def __getattr__(self, attr):
        # Only reached for names the wrapper itself does not define.
        wrapped = self._method
        return getattr(wrapped, attr)

    def __call__(self, *args, **kw):
        kw.update(connection=self._connection)
        return self._method(*args, **kw)

    def __repr__(self):
        details = (self._method, self._connection)
        return '<Wrapped %r with connection %r>' % details
0238
0239class DBAPI(DBConnection):
0240
0241 """
0242 Subclass must define a `makeConnection()` method, which
0243 returns a newly-created connection object.
0244
0245 ``queryInsertID`` must also be defined.
0246 """
0247
0248 dbName = None
0249
0250 def __init__(self, **kw):
0251 self._pool = []
0252 self._poolLock = threading.Lock()
0253 DBConnection.__init__(self, **kw)
0254 self._binaryType = type(self.module.Binary(''))
0255
0256 def _runWithConnection(self, meth, *args):
0257 conn = self.getConnection()
0258 try:
0259 val = meth(conn, *args)
0260 finally:
0261 self.releaseConnection(conn)
0262 return val
0263
0264 def getConnection(self):
0265 self._poolLock.acquire()
0266 try:
0267 if not self._pool:
0268 conn = self.makeConnection()
0269 self._connectionNumbers[id(conn)] = self._connectionCount
0270 self._connectionCount += 1
0271 else:
0272 conn = self._pool.pop()
0273 if self.debug:
0274 s = 'ACQUIRE'
0275 if self._pool is not None:
0276 s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0277 self.printDebug(conn, s, 'Pool')
0278 return conn
0279 finally:
0280 self._poolLock.release()
0281
0282 def releaseConnection(self, conn, explicit=False):
0283 if self.debug:
0284 if explicit:
0285 s = 'RELEASE (explicit)'
0286 else:
0287 s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
0288 if self._pool is None:
0289 s += ' no pooling'
0290 else:
0291 s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0292 self.printDebug(conn, s, 'Pool')
0293 if self.supportTransactions and not explicit:
0294 if self.autoCommit == 'exception':
0295 if self.debug:
0296 self.printDebug(conn, 'auto/exception', 'ROLLBACK')
0297 conn.rollback()
0298 raise Exception, 'Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed'
0299 elif self.autoCommit:
0300 if self.debug:
0301 self.printDebug(conn, 'auto', 'COMMIT')
0302 if not getattr(conn, 'autocommit', False):
0303 conn.commit()
0304 else:
0305 if self.debug:
0306 self.printDebug(conn, 'auto', 'ROLLBACK')
0307 conn.rollback()
0308 if self._pool is not None:
0309 if conn not in self._pool:
0310
0311
0312
0313 self._pool.insert(0, conn)
0314 else:
0315 conn.close()
0316
0317 def printDebug(self, conn, s, name, type='query'):
0318 if name == 'Pool' and self.debug != 'Pool':
0319 return
0320 if type == 'query':
0321 sep = ': '
0322 else:
0323 sep = '->'
0324 s = repr(s)
0325 n = self._connectionNumbers[id(conn)]
0326 spaces = ' '*(8-len(name))
0327 if self.debugThreading:
0328 threadName = threading.currentThread().getName()
0329 threadName = (':' + threadName + ' '*(8-len(threadName)))
0330 else:
0331 threadName = ''
0332 msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
0333 self.debugWriter.write(msg)
0334
0335 def _executeRetry(self, conn, cursor, query):
0336 if self.debug:
0337 self.printDebug(conn, query, 'QueryR')
0338 return cursor.execute(query)
0339
0340 def _query(self, conn, s):
0341 if self.debug:
0342 self.printDebug(conn, s, 'Query')
0343 self._executeRetry(conn, conn.cursor(), s)
0344
0345 def query(self, s):
0346 return self._runWithConnection(self._query, s)
0347
0348 def _queryAll(self, conn, s):
0349 if self.debug:
0350 self.printDebug(conn, s, 'QueryAll')
0351 c = conn.cursor()
0352 self._executeRetry(conn, c, s)
0353 value = c.fetchall()
0354 if self.debugOutput:
0355 self.printDebug(conn, value, 'QueryAll', 'result')
0356 return value
0357
0358 def queryAll(self, s):
0359 return self._runWithConnection(self._queryAll, s)
0360
0361 def _queryAllDescription(self, conn, s):
0362 """
0363 Like queryAll, but returns (description, rows), where the
0364 description is cursor.description (which gives row types)
0365 """
0366 if self.debug:
0367 self.printDebug(conn, s, 'QueryAllDesc')
0368 c = conn.cursor()
0369 self._executeRetry(conn, c, s)
0370 value = c.fetchall()
0371 if self.debugOutput:
0372 self.printDebug(conn, value, 'QueryAll', 'result')
0373 return c.description, value
0374
0375 def queryAllDescription(self, s):
0376 return self._runWithConnection(self._queryAllDescription, s)
0377
0378 def _queryOne(self, conn, s):
0379 if self.debug:
0380 self.printDebug(conn, s, 'QueryOne')
0381 c = conn.cursor()
0382 self._executeRetry(conn, c, s)
0383 value = c.fetchone()
0384 if self.debugOutput:
0385 self.printDebug(conn, value, 'QueryOne', 'result')
0386 return value
0387
0388 def queryOne(self, s):
0389 return self._runWithConnection(self._queryOne, s)
0390
0391 def _insertSQL(self, table, names, values):
0392 return ("INSERT INTO %s (%s) VALUES (%s)" %
0393 (table, ', '.join(names),
0394 ', '.join([self.sqlrepr(v) for v in values])))
0395
0396 def transaction(self):
0397 return Transaction(self)
0398
0399 def queryInsertID(self, soInstance, id, names, values):
0400 return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)
0401
0402 def iterSelect(self, select):
0403 return select.IterationClass(self, self.getConnection(),
0404 select, keepConnection=False)
0405
0406 def accumulateSelect(self, select, *expressions):
0407 """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
0408 to the select object.
0409 """
0410 q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
0411 q = self.sqlrepr(q)
0412 val = self.queryOne(q)
0413 if len(expressions) == 1:
0414 val = val[0]
0415 return val
0416
0417 def queryForSelect(self, select):
0418 return self.sqlrepr(select.queryForSelect())
0419
0420 def _SO_createJoinTable(self, join):
0421 self.query(self._SO_createJoinTableSQL(join))
0422
0423 def _SO_createJoinTableSQL(self, join):
0424 return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
0425 (join.intermediateTable,
0426 join.joinColumn,
0427 self.joinSQLType(join),
0428 join.otherColumn,
0429 self.joinSQLType(join)))
0430
0431 def _SO_dropJoinTable(self, join):
0432 self.query("DROP TABLE %s" % join.intermediateTable)
0433
0434 def _SO_createIndex(self, soClass, index):
0435 self.query(self.createIndexSQL(soClass, index))
0436
0437 def createIndexSQL(self, soClass, index):
0438 assert 0, 'Implement in subclasses'
0439