0001import threading
0002from util.threadinglocal import local as threading_local
0003import sys
0004import re
0005import warnings
0006import atexit
0007import os
0008import new
0009import types
0010import urllib
0011import weakref
0012import inspect
0013import sqlbuilder
0014from cache import CacheSet
0015import col
0016import main
0017from joins import sorter
0018from converters import sqlrepr
0019import classregistry
0020
# Silence the DB-API "cursor.lastrowid" extension warning (presumably raised
# by drivers when backend code reads cursor.lastrowid -- confirm in backends).
warnings.filterwarnings(
    "ignore",
    "DB-API extension cursor.lastrowid used")

# Module-level connection registry; it is populated outside the visible part
# of this module.
_connections = dict()
0024
def _closeConnection(ref):
    """Close the connection referenced by the weak reference *ref*.

    Registered with ``atexit``; does nothing when the referent has already
    been collected (``ref()`` returns None).
    """
    connection = ref()
    if connection is None:
        return
    connection.close()
0029
class ConsoleWriter:
    """Debug writer that prints every message as one line on a ``sys`` stream.

    *loglevel* is the name of the ``sys`` attribute to use (for example
    ``"stdout"`` or ``"stderr"``); a false value selects ``sys.stdout``.
    The stream object is looked up once, when the writer is created.
    """

    def __init__(self, loglevel):
        self.loglevel = loglevel
        streamName = loglevel or "stdout"
        self.logfile = getattr(sys, streamName)

    def write(self, text):
        line = text + '\n'
        self.logfile.write(line)
0036
class LogWriter:
    """Debug writer that forwards each message to a logger method.

    *loglevel* is the name of the method called on *logger* for every
    message (e.g. ``"debug"`` or ``"info"``).  It is resolved once, at
    construction time.
    """

    def __init__(self, logger, loglevel):
        self.logger, self.loglevel = logger, loglevel
        self.logmethod = getattr(self.logger, self.loglevel)

    def write(self, text):
        emit = self.logmethod
        emit(text)
0044
def makeDebugWriter(loggerName, loglevel):
    """Return the object that receives a connection's debug messages.

    Without *loggerName*, messages go to a ``sys`` stream through
    ConsoleWriter, and *loglevel* names that stream (a false value means
    stdout).  Otherwise they go to ``logging.getLogger(loggerName)`` through
    LogWriter, and *loglevel* names the logger method to call.  It defaults to
    ``"debug"`` when not given, because LogWriter would otherwise evaluate
    ``getattr(logger, None)`` and fail with TypeError.
    """
    if not loggerName:
        return ConsoleWriter(loglevel)
    import logging
    logger = logging.getLogger(loggerName)
    return LogWriter(logger, loglevel or 'debug')
0051
class DBConnection:
    """Backend-independent base for SQLObject database connections.

    Stores the debug, caching, style and autocommit settings and provides
    URI helpers.  Concrete backends must override ``isSupported`` and
    ``connectionFromURI``.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        self.name = name
        self.debug = debug
        self.debugOutput = debugOutput
        self.debugThreading = debugThreading
        self.debugWriter = makeDebugWriter(logger, loglevel)
        self.cache = CacheSet(cache=cache)
        self.doCache = cache
        self.style = style
        # id(raw connection) -> small sequential number shown in debug output.
        self._connectionNumbers = {}
        self._connectionCount = 1
        self.autoCommit = autoCommit
        # Normalise any false registry value to None.
        self.registry = registry or None
        # soClassAdded is registered as a callback on the class registry
        # (presumably invoked for each class added to it -- see its docstring).
        classregistry.registry(self.registry).addCallback(
            self.soClassAdded)
        registerConnectionInstance(self)
        # Only a weak reference is given to atexit so that the exit hook does
        # not by itself keep this connection alive.
        atexit.register(_closeConnection, weakref.ref(self))

    def uri(self):
        """Return ``dbName://[user[:password]@][host]/db`` for this connection.

        Reads the ``dbName``, ``user``, ``password``, ``host`` and ``db``
        attributes, which concrete backends are expected to set.  A port is
        not included.  A leading ``/`` on ``db`` is dropped because the URI
        already has one after the host part.

        Raises AssertionError if a password is set without a user name.
        """
        auth = getattr(self, 'user', None) or ''
        password = getattr(self, 'password', None)
        if auth:
            if password:
                auth = auth + ':' + password
            auth = auth + '@'
        else:
            assert not password, (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        host = getattr(self, 'host', None)
        if host:
            uri += host
        uri += '/'
        db = self.db
        if db.startswith('/'):
            # Avoid a doubled slash: the separator was appended above.
            db = db[1:]
        return uri + db

    def isSupported(cls):
        """Abstract hook; every concrete backend must override it."""
        raise NotImplementedError
    isSupported = classmethod(isSupported)

    def connectionFromURI(cls, uri):
        """Abstract hook: build a connection from *uri*; backends override it."""
        raise NotImplementedError
    connectionFromURI = classmethod(connectionFromURI)

    def _parseURI(uri):
        """Split a connection URI into ``(user, password, host, port, path, args)``.

        Accepted forms are ``scheme:/path``, ``scheme:///path`` (no host) and
        ``scheme://[user[:password]@]host[:port]/path``, optionally followed by
        ``?name=value&...``.  ``user``, ``password`` and ``host`` are strings or
        None.  ``port`` is an int or None.  ``path`` starts with ``/``, except
        on Windows, where ``C|rest`` is turned into ``C:rest``.  ``args`` maps
        query-string names to URL-unquoted values.

        Raises ValueError if the port is not an integer or lies outside
        1-65535, if the URI has no ``:``, or if a query argument has no ``=``.
        Raises AssertionError (unless run with -O) if the part after the
        scheme does not start with ``/``.
        """
        schema, rest = uri.split(':', 1)
        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
        if rest.startswith('/') and not rest.startswith('//'):
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            host = None
            rest = rest[3:]
        else:
            rest = rest[2:]
            if '/' not in rest:
                host = rest
                rest = ''
            else:
                host, rest = rest.split('/', 1)
        if host and '@' in host:
            # Split on the LAST '@' so that credentials may contain '@'.
            at = host.rfind('@')
            user = host[:at]
            host = host[at + 1:]
            if ':' in user:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and ':' in host:
            _host, port = host.split(':')
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
            host = _host
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            # A URI such as 'C|/dir/file' denotes the local path 'C:/dir/file'.
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if '?' in path:
            path, arglist = path.split('?', 1)
            for single in arglist.split('&'):
                argname, argvalue = single.split('=', 1)
                args[argname] = urllib.unquote(argvalue)
        return user, password, host, port, path, args
    _parseURI = staticmethod(_parseURI)

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.

        The class becomes available as ``connection.<ClassName>``, wrapped
        in a ConnWrapper.  Raises AssertionError if this connection already
        has an attribute with that name.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.

        Calls ``weakrefAll()`` on the cache set first (presumably so that the
        cache stops holding strong references), then ``expire()`` on every
        object that ``getAll()`` reports.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()
0173
class ConnWrapper(object):

    """
    Proxy for an SQLObject class that is bound to one connection.

    Instances carry their own ``connection`` attribute, but classes are
    global.  This wrapper therefore supplies the connection lazily:

    * calling the wrapper builds an instance with ``connection=<bound>``;
    * a class method that accepts a ``connection`` argument is returned
      wrapped in a ConnMethodWrapper, which injects that connection.

    Every other attribute is passed through unchanged.
    """

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        target = getattr(self._soClass, attr)
        if not isinstance(target, types.MethodType):
            # Data attributes and other non-method objects pass straight through.
            return target
        try:
            needsConn = target.takes_connection
        except AttributeError:
            # First lookup of this method: work out from its signature whether
            # it accepts ``connection``, and cache the answer on the function.
            argNames, starArgs, starKw, _defaults = inspect.getargspec(target)
            assert not (starKw or starArgs), (
                "I cannot tell whether I must wrap this method, "
                "because it takes **kw: %r"
                % target)
            needsConn = 'connection' in argNames
            target.im_func.takes_connection = needsConn
        if needsConn:
            return ConnMethodWrapper(target, self._connection)
        return target
0212
class ConnMethodWrapper(object):

    """Callable proxy that always passes a fixed ``connection`` keyword.

    Calling the proxy calls the wrapped method with ``connection=`` set to
    the bound connection, overriding any ``connection`` the caller gave.
    Attribute lookups that the proxy cannot satisfy are forwarded to the
    wrapped method.
    """

    def __init__(self, method, connection):
        self._method, self._connection = method, connection

    def __repr__(self):
        return '<Wrapped %r with connection %r>' % (
            self._method, self._connection)

    def __call__(self, *args, **kw):
        kw.update(connection=self._connection)
        target = self._method
        return target(*args, **kw)

    def __getattr__(self, attr):
        return getattr(self._method, attr)
0229
0230class DBAPI(DBConnection):
0231
0232 """
0233 Subclass must define a `makeConnection()` method, which
0234 returns a newly-created connection object.
0235
0236 ``queryInsertID`` must also be defined.
0237 """
0238
0239 dbName = None
0240
0241 def __init__(self, **kw):
0242 self._pool = []
0243 self._poolLock = threading.Lock()
0244 DBConnection.__init__(self, **kw)
0245 self._binaryType = type(self.module.Binary(''))
0246
0247 def _runWithConnection(self, meth, *args):
0248 conn = self.getConnection()
0249 try:
0250 val = meth(conn, *args)
0251 finally:
0252 self.releaseConnection(conn)
0253 return val
0254
0255 def getConnection(self):
0256 self._poolLock.acquire()
0257 try:
0258 if not self._pool:
0259 conn = self.makeConnection()
0260 self._connectionNumbers[id(conn)] = self._connectionCount
0261 self._connectionCount += 1
0262 else:
0263 conn = self._pool.pop()
0264 if self.debug:
0265 s = 'ACQUIRE'
0266 if self._pool is not None:
0267 s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0268 self.printDebug(conn, s, 'Pool')
0269 return conn
0270 finally:
0271 self._poolLock.release()
0272
0273 def releaseConnection(self, conn, explicit=False):
0274 if self.debug:
0275 if explicit:
0276 s = 'RELEASE (explicit)'
0277 else:
0278 s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
0279 if self._pool is None:
0280 s += ' no pooling'
0281 else:
0282 s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
0283 self.printDebug(conn, s, 'Pool')
0284 if self.supportTransactions and not explicit:
0285 if self.autoCommit == 'exception':
0286 if self.debug:
0287 self.printDebug(conn, 'auto/exception', 'ROLLBACK')
0288 conn.rollback()
0289 raise Exception, 'Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed'
0290 elif self.autoCommit:
0291 if self.debug:
0292 self.printDebug(conn, 'auto', 'COMMIT')
0293 if not getattr(conn, 'autocommit', False):
0294 conn.commit()
0295 else:
0296 if self.debug:
0297 self.printDebug(conn, 'auto', 'ROLLBACK')
0298 conn.rollback()
0299 if self._pool is not None:
0300 if conn not in self._pool:
0301
0302
0303
0304 self._pool.insert(0, conn)
0305 else:
0306 conn.close()
0307
0308 def printDebug(self, conn, s, name, type='query'):
0309 if name == 'Pool' and self.debug != 'Pool':
0310 return
0311 if type == 'query':
0312 sep = ': '
0313 else:
0314 sep = '->'
0315 s = repr(s)
0316 n = self._connectionNumbers[id(conn)]
0317 spaces = ' '*(8-len(name))
0318 if self.debugThreading:
0319 threadName = threading.currentThread().getName()
0320 threadName = (':' + threadName + ' '*(8-len(threadName)))
0321 else:
0322 threadName = ''
0323 msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
0324 self.debugWriter.write(msg)
0325
0326 def _executeRetry(self, conn, cursor, query):
0327 if self.debug:
0328 self.printDebug(conn, query, 'QueryR')
0329 return cursor.execute(query)
0330
0331 def _query(self, conn, s):
0332 if self.debug:
0333 self.printDebug(conn, s, 'Query')
0334 self._executeRetry(conn, conn.cursor(), s)
0335
0336 def query(self, s):
0337 return self._runWithConnection(self._query, s)
0338
0339 def _queryAll(self, conn, s):
0340 if self.debug:
0341 self.printDebug(conn, s, 'QueryAll')
0342 c = conn.cursor()
0343 self._executeRetry(conn, c, s)
0344 value = c.fetchall()
0345 if self.debugOutput:
0346 self.printDebug(conn, value, 'QueryAll', 'result')
0347 return value
0348
0349 def queryAll(self, s):
0350 return self._runWithConnection(self._queryAll, s)
0351
0352 def _queryAllDescription(self, conn, s):
0353 """
0354 Like queryAll, but returns (description, rows), where the
0355 description is cursor.description (which gives row types)
0356 """
0357 if self.debug:
0358 self.printDebug(conn, s, 'QueryAllDesc')
0359 c = conn.cursor()
0360 self._executeRetry(conn, c, s)
0361 value = c.fetchall()
0362 if self.debugOutput:
0363 self.printDebug(conn, value, 'QueryAll', 'result')
0364 return c.description, value
0365
0366 def queryAllDescription(self, s):
0367 return self._runWithConnection(self._queryAllDescription, s)
0368
0369 def _queryOne(self, conn, s):
0370 if self.debug:
0371 self.printDebug(conn, s, 'QueryOne')
0372 c = conn.cursor()
0373 self._executeRetry(conn, c, s)
0374 value = c.fetchone()
0375 if self.debugOutput:
0376 self.printDebug(conn, value, 'QueryOne', 'result')
0377 return value
0378
0379 def queryOne(self, s):
0380 return self._runWithConnection(self._queryOne, s)
0381
0382 def _insertSQL(self, table, names, values):
0383 return ("INSERT INTO %s (%s) VALUES (%s)" %
0384 (table, ', '.join(names),
0385 ', '.join([self.sqlrepr(v) for v in values])))
0386
0387 def transaction(self):
0388 return Transaction(self)
0389
0390 def queryInsertID(self, soInstance, id, names, values):
0391 return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)
0392
0393 def iterSelect(self, select):
0394 return select.IterationClass(self, self.getConnection(),
0395 select, keepConnection=False)
0396
0397 def accumulateSelect(self, select, *expressions):
0398 """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
0399 to the select object.
0400 """
0401 q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
0402 q = self.sqlrepr(q)
0403 val = self.queryOne(q)
0404 if len(expressions) == 1:
0405 val = val[0]
0406 return val
0407
0408 def queryForSelect(self, select):
0409 return self.sqlrepr(select.queryForSelect())
0410
0411 def _SO_createJoinTable(self, join):
0412 self.query(self._SO_createJoinTableSQL(join))
0413
0414 def _SO_createJoinTableSQL(self, join):
0415 return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
0416 (join.intermediateTable,
0417 join.joinColumn,
0418 self.joinSQLType(join),
0419 join.otherColumn,
0420 self.joinSQLType(join)))
0421
0422 def _SO_dropJoinTable(self, join):
0423 self.query("DROP TABLE %s" % join.intermediateTable)
0424
0425 def _SO_createIndex(self, soClass, index):
0426 self.query(self.createIndexSQL(soClass, index))
0427
0428 def createIndexSQL(self, soClass, index):
0429 assert 0, 'Implement in subclasses'
0430
0431 def createTable(self, soClass):
0432 createSql, constraints = self.createTableSQL(soClass)
0433 self.query(createSql)
0434
0435 return constraints
0436
0437 def createReferenceConstraints(self, soClass):
0438 refConstraints = [self.createReferenceConstraint(soClass, column) for column in soClass.sqlmeta.columnList if isinstance(column, col