0001import atexit
0002from cgi import parse_qsl
0003import inspect
0004import new
0005import os
0006import sys
0007import threading
0008import types
0009import urllib
0010import warnings
0011import weakref
0012
0013from cache import CacheSet
0014import classregistry
0015import col
0016from converters import sqlrepr
0017import main
0018import sqlbuilder
0019from util.threadinglocal import local as threading_local
0020
# Ignore warnings whose message starts with the text below.
warnings.filterwarnings(
    action="ignore",
    message="DB-API extension cursor.lastrowid used")

# Module-level dict; it starts out empty.
_connections = dict()
0024
def _closeConnection(ref):
    """
    Close the object a reference callable points to, if it still exists.

    ``ref`` is called without arguments.  When it returns ``None`` nothing
    happens; otherwise the returned object's ``close()`` method is called.
    """
    target = ref()
    if target is None:
        return
    target.close()
0029
class ConsoleWriter:
    """
    Debug writer that sends each message to a stream looked up on ``sys``.
    """

    def __init__(self, connection, loglevel):
        # ``loglevel`` is used as the name of the ``sys`` attribute to write
        # to; any false value selects "stdout".
        if not loglevel:
            loglevel = "stdout"
        self.loglevel = loglevel
        # Encoding comes from the connection when it has one, else "ascii".
        encoding = getattr(connection, "dbEncoding", None)
        if not encoding:
            encoding = "ascii"
        self.dbEncoding = encoding

    def write(self, text):
        """Write ``text`` followed by a newline to the selected stream."""
        stream = getattr(sys, self.loglevel)
        if isinstance(text, unicode):
            try:
                text = text.encode(self.dbEncoding)
            except UnicodeEncodeError:
                # Not encodable: fall back to the repr without its first
                # two characters and its last one.
                text = repr(text)[2:-1]
        stream.write(text + '\n')
0043
class LogWriter:
    """
    Debug writer that forwards each message to one method of a logger.

    ``loglevel`` is the name of the logger method to call (for example
    ``'debug'`` or ``'info'``).  A false value selects ``'debug'``.
    """

    def __init__(self, connection, logger, loglevel):
        self.logger = logger
        # Default the level the same way ConsoleWriter defaults its stream:
        # without this, getattr() with a None name raises TypeError whenever
        # a logger is configured without an explicit loglevel.
        self.loglevel = loglevel or 'debug'
        self.logmethod = getattr(logger, self.loglevel)

    def write(self, text):
        """Pass ``text`` to the selected logger method."""
        self.logmethod(text)
0051
def makeDebugWriter(connection, loggerName, loglevel):
    """
    Build the writer object used for debug output.

    With a logger name, messages go to ``logging.getLogger(loggerName)``
    through a LogWriter; without one, a ConsoleWriter is returned.
    """
    if loggerName:
        import logging
        return LogWriter(connection, logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(connection, loglevel)
0058
class Boolean(object):
    """
    Bool-like factory that also recognises a few string keywords.

    ``Boolean(value)`` returns a plain ``bool``.  Values with a ``lower()``
    method are matched case-insensitively against yes/no, true/false,
    on/off and 1/0; everything else falls back to ``bool(value)``.
    """
    _keywords = {
        '1': True, 'yes': True, 'true': True, 'on': True,
        '0': False, 'no': False, 'false': False, 'off': False,
    }

    def __new__(cls, value):
        try:
            keyword = value.lower()
            result = Boolean._keywords[keyword]
        except (AttributeError, KeyError):
            # Not string-like, or not a known keyword: ordinary truthiness.
            result = bool(value)
        return result
0068
class DBConnection:
    """
    Base class holding the state shared by connections: debug flags, the
    debug writer, the object cache, the style and the class-registry hookup.
    It can also render itself as a URI and parse connection URIs.

    The methods below read ``dbName``, ``user``, ``password``, ``host``,
    ``port`` and ``db`` and call ``_connectionFromParams``; none of these is
    defined in this class, so they have to come from subclasses.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        """
        Store the connection options and register the new instance.

        ``debug``, ``debugOutput``, ``debugThreading``, ``cache`` and
        ``autoCommit`` are normalised through ``Boolean``, so keyword
        strings such as ``'yes'`` or ``'off'`` are accepted.  ``logger`` and
        ``loglevel`` are passed to ``makeDebugWriter``.
        """
        self.name = name
        self.debug = Boolean(debug)
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        self.debugWriter = makeDebugWriter(self, logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        self.style = style
        # Maps id(raw connection) -> sequence number, used in debug output.
        self._connectionNumbers = {}
        self._connectionCount = 1
        self.autoCommit = Boolean(autoCommit)
        self.registry = registry or None
        # Have soClassAdded called for classes added to this registry.
        classregistry.registry(self.registry).addCallback(self.soClassAdded)
        registerConnectionInstance(self)
        # Close the connection at interpreter exit; atexit is only given a
        # weak reference to this instance.
        atexit.register(_closeConnection, weakref.ref(self))

    def oldUri(self):
        """
        Return this connection as a URI string without percent-quoting.

        Asserts that no password is set when there is no user name.
        """
        auth = getattr(self, 'user', '') or ''
        if auth:
            if self.password:
                auth = auth + ':' + self.password
            auth = auth + '@'
        else:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + db

    def uri(self):
        """
        Return this connection as a URI string, percent-quoting the user
        name, the password and the database part.

        Asserts that no password is set when there is no user name.
        """
        auth = getattr(self, 'user', '') or ''
        if auth:
            auth = urllib.quote(auth)
            if self.password:
                auth = auth + ':' + urllib.quote(self.password)
            auth = auth + '@'
        else:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + urllib.quote(db)

    @classmethod
    def connectionFromOldURI(cls, uri):
        """Build a connection from a URI in the old (unquoted) format."""
        return cls._connectionFromParams(*cls._parseOldURI(uri))

    @classmethod
    def connectionFromURI(cls, uri):
        """Build a connection from a percent-quoted URI."""
        return cls._connectionFromParams(*cls._parseURI(uri))

    @staticmethod
    def _parseOldURI(uri):
        """
        Split an old-format URI into its parts.

        Returns ``(user, password, host, port, path, args)``; missing parts
        are ``None`` and ``args`` is a dict built from the query string
        (values are unquoted, names are not).

        Raises ValueError when the port is not an integer in 1-65535.
        """
        schema, rest = uri.split(':', 1)
        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
        if rest.startswith('/') and not rest.startswith('//'):
            # scheme:/path -- no host part
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            # scheme:///path -- empty host part
            host = None
            rest = rest[3:]
        else:
            rest = rest[2:]
            if rest.find('/') == -1:
                host = rest
                rest = ''
            else:
                host, rest = rest.split('/', 1)
        if host and host.find('@') != -1:
            user, host = host.rsplit('@', 1)
            if user.find(':') != -1:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and host.find(':') != -1:
            _host, port = host.split(':')
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
            host = _host
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            # "C|/dir" style paths: turn the '|' into ':' and drop the
            # leading slash.
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if path.find('?') != -1:
            path, arglist = path.split('?', 1)
            for single in arglist.split('&'):
                if not single:
                    # Empty segment (trailing '?' or '&', or '&&'): there
                    # is nothing to record, so do not fail on it.
                    continue
                if '=' in single:
                    argname, argvalue = single.split('=', 1)
                else:
                    # A bare name is treated like "name=".
                    argname, argvalue = single, ''
                args[argname] = urllib.unquote(argvalue)
        return user, password, host, port, path, args

    @staticmethod
    def _parseURI(uri):
        """
        Split a percent-quoted URI into its parts.

        Returns ``(user, password, host, port, path, args)``; missing parts
        are ``None`` and ``args`` is a dict built from the query string.
        The fragment (``#...``) is split off and ignored.
        """
        protocol, request = urllib.splittype(uri)
        user, password, port = None, None, None
        host, path = urllib.splithost(request)

        if host:
            # Split on the LAST '@', as _parseOldURI does: the host part
            # never contains '@', while a badly quoted password might.
            if '@' in host:
                user, host = host.rsplit('@', 1)
            if user:
                user, password = [x and urllib.unquote(x) or None for x in urllib.splitpasswd(user)]
            host, port = urllib.splitport(host)
            if port:
                port = int(port)
        elif host == '':
            host = None

        path, tag = urllib.splittag(path)
        path, query = urllib.splitquery(path)

        path = urllib.unquote(path)
        if (os.name == 'nt') and (len(path) > 2):
            # "/C|/dir" -> "/C:/dir"
            if path[2] == '|':
                path = "%s:%s" % (path[0:2], path[3:])
            # "/C:/dir" -> "C:/dir"
            if (path[0] == '/') and (path[2] == ':'):
                path = path[1:]

        args = {}
        if query:
            for name, value in parse_qsl(query):
                args[name] = value

        return user, password, host, port, path, args

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()
0250
class ConnWrapper(object):

    """
    Pairs a SQLObject class with one particular connection.

    Classes are global while the connection is per instance, so this
    wrapper supplies the connection lazily: calling the wrapper or one of
    its connection-aware methods passes ``connection=`` automatically.
    """

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        """Instantiate the wrapped class with ``connection`` forced in."""
        kw.update(connection=self._connection)
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        """
        Look ``attr`` up on the wrapped class.

        Non-method attributes and methods without a ``connection``
        argument are returned unchanged; other methods are wrapped so the
        connection is passed for the caller.
        """
        member = getattr(self._soClass, attr)
        if isinstance(member, types.MethodType):
            try:
                wants_connection = member.takes_connection
            except AttributeError:
                # First access: inspect the signature and remember the
                # answer on the underlying function.
                spec = inspect.getargspec(member)
                positional, star_args, star_kw = spec[0], spec[1], spec[2]
                assert not star_kw and not star_args, (
                    "I cannot tell whether I must wrap this method, "
                    "because it takes **kw: %r"
                    % member)
                wants_connection = 'connection' in positional
                member.im_func.takes_connection = wants_connection
            if wants_connection:
                return ConnMethodWrapper(member, self._connection)
        return member
0289
class ConnMethodWrapper(object):
    """
    Callable that invokes a method with a fixed ``connection`` keyword.

    Any other attribute access is delegated to the wrapped method.
    """

    def __init__(self, method, connection):
        self._method = method
        self._connection = connection

    def __getattr__(self, attr):
        wrapped = self._method
        return getattr(wrapped, attr)

    def __call__(self, *args, **kw):
        kw.update(connection=self._connection)
        return self._method(*args, **kw)

    def __repr__(self):
        parts = (self._method, self._connection)
        return '<Wrapped %r with connection %r>' % parts
0306
class DBAPI(DBConnection):

    """
    Connection base class for DB-API style drivers.

    Subclasses must define a ``makeConnection()`` method that returns a
    newly created raw connection; ``queryInsertID`` support has to be
    provided as well.  Idle raw connections are kept in ``self._pool``
    and handed out again by ``getConnection()``.
    """

    dbName = None

    def __init__(self, **kw):
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        # Remember the concrete type the driver module uses for binary data.
        sample = self.module.Binary('')
        self._binaryType = type(sample)

    def _runWithConnection(self, meth, *args):
        """
        Call ``meth(conn, *args)`` with a connection from getConnection()
        and release that connection afterwards, even if ``meth`` raised.
        """
        raw = self.getConnection()
        try:
            return meth(raw, *args)
        finally:
            self.releaseConnection(raw)

    def getConnection(self):
        """
        Return a raw connection, taken from the pool when one is idle and
        created with ``makeConnection()`` otherwise.  Runs under the pool
        lock.
        """
        self._poolLock.acquire()
        try:
            if self._pool:
                conn = self._pool.pop()
            else:
                conn = self.makeConnection()
                # Number new connections so debug output can tell them apart.
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            if self.debug:
                label = 'ACQUIRE'
                if self._pool is not None:
                    pooled = [str(self._connectionNumbers[id(v)]) for v in self._pool]
                    label += ' pool=[%s]' % ', '.join(pooled)
                self.printDebug(conn, label, 'Pool')
            return conn
        finally:
            self._poolLock.release()

    def releaseConnection(self, conn, explicit=False):
        """
        Give ``conn`` back after use.

        Unless the release is ``explicit`` (and provided
        ``supportTransactions`` is true) pending work is committed or
        rolled back according to ``autoCommit``.  Afterwards the connection
        is put back into the pool, or closed when the pool is ``None``.
        """
        if self.debug:
            if explicit:
                label = 'RELEASE (explicit)'
            else:
                label = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                label += ' no pooling'
            else:
                pooled = [str(self._connectionNumbers[id(v)]) for v in self._pool]
                label += ' pool=[%s]' % ', '.join(pooled)
            self.printDebug(conn, label, 'Pool')
        if self.supportTransactions and not explicit:
            if self.autoCommit == 'exception':
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                raise Exception('Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed')
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                # Skip the commit when the raw connection reports that it
                # is in autocommit mode itself.
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        if self._pool is None:
            conn.close()
        elif conn not in self._pool:
            # A connection already in the pool is not inserted twice.
            self._pool.insert(0, conn)

    def printDebug(self, conn, s, name, type='query'):
        """
        Write one debug line about ``conn`` through ``self.debugWriter``.

        'Pool' messages are only written when ``self.debug`` equals
        ``'Pool'``.  ``type`` selects the separator: ``': '`` for
        ``'query'`` and ``'->'`` for anything else.
        """
        if name == 'Pool' and self.debug != 'Pool':
            return
        sep = '->'
        if type == 'query':
            sep = ': '
        text = repr(s)
        number = self._connectionNumbers[id(conn)]
        padding = ' ' * (8 - len(name))
        threadLabel = ''
        if self.debugThreading:
            current = threading.currentThread().getName()
            threadLabel = ':' + current + ' ' * (8 - len(current))
        fields = {'n': number, 'threadName': threadLabel, 'name': name,
                  'spaces': padding, 'sep': sep, 's': text}
        self.debugWriter.write(
            '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % fields)

    def _executeRetry(self, conn, cursor, query):
        """Execute ``query`` on ``cursor`` and return the driver's result."""
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)

    def _query(self, conn, s):
        """Execute statement ``s`` on a fresh cursor of ``conn``."""
        if self.debug:
            self.printDebug(conn, s, 'Query')
        cursor = conn.cursor()
        self._executeRetry(conn, cursor, s)

    def query(self, s):
        """Execute statement ``s`` on a pooled connection."""
        return self._runWithConnection(self._query, s)

    def _queryAll(self, conn, s):
        """Execute ``s`` on ``conn`` and return ``fetchall()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        cursor = conn.cursor()
        self._executeRetry(conn, cursor, s)
        rows = cursor.fetchall()
        if self.debugOutput:
            self.printDebug(conn, rows, 'QueryAll', 'result')
        return rows

    def queryAll(self, s):
        """Execute ``s`` on a pooled connection and return all rows."""
        return self._runWithConnection(self._queryAll, s)

    def _queryAllDescription(self, conn, s):
        """
        Same as _queryAll, except that the result is the pair
        ``(cursor.description, rows)``.
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        cursor = conn.cursor()
        self._executeRetry(conn, cursor, s)
        rows = cursor.fetchall()
        if self.debugOutput:
            self.printDebug(conn, rows, 'QueryAll', 'result')
        return cursor.description, rows

    def queryAllDescription(self, s):
        """Run _queryAllDescription on a pooled connection."""
        return self._runWithConnection(self._queryAllDescription, s)

    def _queryOne(self, conn, s):
        """Execute ``s`` on ``conn`` and return ``fetchone()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        cursor = conn.cursor()
        self._executeRetry(conn, cursor, s)
        row = cursor.fetchone()
        if self.debugOutput:
            self.printDebug(conn, row, 'QueryOne', 'result')
        return row

    def queryOne(self, s):
        """Execute ``s`` on a pooled connection and return one row."""
        return self._runWithConnection(self._queryOne, s)

    def _insertSQL(self, table, names, values):
        """Return the INSERT statement for the given column names/values."""
        columnList = ', '.join(names)
        valueList = ', '.join([self.sqlrepr(v) for v in values])
        return "INSERT INTO %s (%s) VALUES (%s)" % (table, columnList, valueList)

    def transaction(self):
        """Return a new Transaction bound to this connection."""
        return Transaction(self)

    def queryInsertID(self, soInstance, id, names, values):
        """Run ``_queryInsertID`` on a pooled connection."""
        return self._runWithConnection(
            self._queryInsertID, soInstance, id, names, values)

    def iterSelect(self, select):
        """
        Return ``select.IterationClass`` instantiated on a connection
        obtained from getConnection(), with ``keepConnection=False``.
        """
        factory = select.IterationClass
        raw = self.getConnection()
        return factory(self, raw, select, keepConnection=False)

    def accumulateSelect(self, select, *expressions):
        """
        Apply accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, ...) to
        the select object.

        With exactly one expression the single value is returned,
        otherwise the whole result row.
        """
        query = select.queryForSelect()
        query = query.newItems(expressions)
        query = query.unlimited()
        query = query.orderBy(None)
        row = self.queryOne(self.sqlrepr(query))
        if len(expressions) == 1:
            return row[0]
        return row

    def queryForSelect(self, select):
        """Return the SQL text for ``select``."""
        built = select.queryForSelect()
        return self.sqlrepr(built)

    def _SO_createJoinTable(self, join):
        """Create the intermediate table of ``join``."""
        statement = self._SO_createJoinTableSQL(join)
        self.query(statement)

    def _SO_createJoinTableSQL(self, join):
        """Return the CREATE TABLE statement for the join's table."""
        parts = (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join))
        return 'CREATE TABLE %s (\n%s %s,\n%s %s\n)' % parts

    def _SO_dropJoinTable(self, join):
        """Drop the intermediate table of ``join``."""
        self.query("DROP TABLE %s" % join.intermediateTable)

    def _SO_createIndex(self, soClass, index):
        """Create ``index`` using the subclass-provided SQL."""
        statement = self.createIndexSQL(soClass, index)
        self.query(statement)

    def createIndexSQL(self, soClass, index):
        assert False, 'Implement in subclasses'

    def createTable(self, soClass):
        """
        Execute the CREATE TABLE statement for ``soClass`` and return the
        second item of createTableSQL() (constraints plus extra SQL).
        """
        statement, constraints = self.createTableSQL(soClass)
        self.query(statement)
        return constraints

    def createReferenceConstraints(self, soClass):
        """
        Return the non-empty reference constraints for the foreign-key
        columns of ``soClass``.
        """
        created = []
        for column in soClass.sqlmeta.columnList:
            if isinstance(column, col.SOForeignKey):
                created.append(self.createReferenceConstraint(soClass, column))
        return [constraint for constraint in created if constraint]

    def createSQL(self, soClass):
        """
        Return ``soClass.sqlmeta.createSQL`` normalised to a list.

        A str, list, tuple or dict is accepted; a dict is looked up by
        ``soClass._connection.dbName``.  A missing or empty value gives
        ``[]``.
        """
        extra = getattr(soClass.sqlmeta, 'createSQL', None)
        if not extra:
            return []
        assert isinstance(extra, (str, list, dict, tuple)), (
            '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
            (soClass.__name__))
        if isinstance(extra, dict):
            extra = extra.get(soClass._connection.dbName, [])
        if isinstance(extra, str):
            extra = [extra]
        if isinstance(extra, tuple):
            extra = list(extra)
        assert isinstance(extra, list), (
            'Unable to create a list from %s.sqlmeta.createSQL' %
            (soClass.__name__))
        return extra or []

    def createTableSQL(self, soClass):
        """Return ``(CREATE TABLE statement, constraints + extra SQL)``."""
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        tableName = soClass.sqlmeta.table
        columnSQL = self.createColumns(soClass)
        statement = 'CREATE TABLE %s (\n%s\n)' % (tableName, columnSQL)
        return statement, constraints + extraSQL

    def createColumns(self, soClass):
        """Return the column definitions (ID column first), one per line."""
        definitions = [self.createIDColumn(soClass)]
        for column in soClass.sqlmeta.columnList:
            definitions.append(self.createColumn(soClass, column))
        return ",\n".join([" %s" % c for c in definitions])

    def createReferenceConstraint(self, soClass, col):
        assert False, "Implement in subclasses"

    def createColumn(self, soClass, col):
        assert False, "Implement in subclasses"

    def dropTable(self, tableName, cascade=False):
        """Drop ``tableName``; ``cascade`` is not used here."""
        statement = "DROP TABLE %s" % tableName
        self.query(statement)

    def clearTable(self, tableName):
        """Remove every row of ``tableName`` with a plain DELETE."""
        statement = "DELETE FROM %s" % tableName
        self.query(statement)

    def createBinary(self, value):
        """
        Wrap ``value`` with the driver module's ``Binary`` constructor.
        """
        wrap = self.module.Binary
        return wrap(value)

    def _SO_update(self, so, values):
        """Update the row of ``so`` with ``(dbName, value)`` pairs."""
        tableName = so.sqlmeta.table
        assignments = ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                                 for dbName, value in values])
        idName = so.sqlmeta.idName
        idValue = self.sqlrepr(so.id)
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (tableName, assignments, idName, idValue))

    def _SO_selectOne(self, so, columnNames):
        """Select ``columnNames`` from the row whose id equals ``so.id``."""
        condition = so.q.id == so.id
        return self._SO_selectOneAlt(so, columnNames, condition)

    def _SO_selectOneAlt(self, so, columnNames, condition):
        """
        Select one row from the table of ``so`` that matches ``condition``.

        String column names are turned into ``SQLConstant`` objects.
        """
        columns = None
        if columnNames:
            columns = [isinstance(x, basestring) and sqlbuilder.SQLConstant(x) or x for x in columnNames]
        select = sqlbuilder.Select(columns,
                                   staticTables=[so.sqlmeta.table],
                                   clause=condition)
        return self.queryOne(self.sqlrepr(select))

    def _SO_delete(self, so):
        """Delete the row of ``so``."""
        parts = (so.sqlmeta.table, so.sqlmeta.idName, self.sqlrepr(so.id))
        self.query("DELETE FROM %s WHERE %s = (%s)" % parts)

    def _SO_selectJoin(self, soClass, column, value):
        """Return the id rows of ``soClass`` where ``column`` equals ``value``."""
        parts = (soClass.sqlmeta.idName,
                 soClass.sqlmeta.table,
                 column,
                 self.sqlrepr(value))
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" % parts)

    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        """Return ``getColumn`` rows of ``table`` where ``joinColumn`` equals ``value``."""
        parts = (getColumn, table, joinColumn, self.sqlrepr(value))
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" % parts)

    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        """Delete the rows of ``table`` matching both column values."""
        parts = (table,
                 firstColumn,
                 self.sqlrepr(firstValue),
                 secondColumn,
                 self.sqlrepr(secondValue))
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" % parts)

    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        """Insert one pair of values into ``table``."""
        parts = (table,
                 firstColumn,
                 secondColumn,
                 self.sqlrepr(firstValue),
                 self.sqlrepr(secondValue))
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" % parts)

    def _SO_columnClause(self, soClass, kw):
        """
        Build an ``AND``-joined SQL condition from keyword arguments.

        Recognised entries are popped from ``kw`` (the dict is modified in
        place); anything left over raises TypeError.  Returns ``None``
        when no column was named.
        """
        # None is compared with IS, everything else with "=".
        ops = {None: "IS"}
        data = {}
        if 'id' in kw:
            data[soClass.sqlmeta.idName] = kw.pop('id')
        for key, column in soClass.sqlmeta.columns.items():
            if key in kw:
                value = kw.pop(key)
                if column.from_python:
                    value = column.from_python(
                        value,
                        sqlbuilder.SQLObjectState(soClass, connection=self))
                data[column.dbName] = value
            elif column.foreignName in kw:
                obj = kw.pop(column.foreignName)
                if isinstance(obj, main.SQLObject):
                    data[column.dbName] = obj.id
                else:
                    data[column.dbName] = obj
        if kw:
            raise TypeError("got an unexpected keyword argument(s): %r" % kw.keys())

        if not data:
            return None
        clauses = []
        for dbName, value in data.items():
            clauses.append('%s %s %s' %
                           (dbName, ops.get(value, "="), self.sqlrepr(value)))
        return ' AND '.join(clauses)

    def sqlrepr(self, v):
        """Return the SQL representation of ``v`` for this ``dbName``."""
        return sqlrepr(v, self.dbName)

    def __del__(self):
        self.close()

    def close(self):
        """
        Close every pooled raw connection and empty the pool.

        Does nothing when the instance has no ``_pool`` attribute or the
        pool is empty.  The driver's ``Error`` raised by a ``close()`` call
        is ignored.
        """
        if not hasattr(self, '_pool') or not self._pool:
            return
        self._poolLock.acquire()
        try:
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    pass
                del conn
            del conns
        finally:
            self._poolLock.release()

    def createEmptyDatabase(self):
        """
        Create an empty database; not implemented in this class.
        """
        raise NotImplementedError
0702
0703class Iteration(object):
0704
0705 def __init__(self, dbconn, rawconn, select, keepConnection=False):
0706 self.dbconn = dbconn
0707 self.rawconn = rawconn
0708 self.select = select
0709 self.keepConnection = keepConnection
0710 self.cursor = rawconn.cursor()
0711 self.query = self.dbconn.queryForSelect(select)
0712 if dbconn.debug:
0713 dbconn.printDebug(rawconn, self.query, 'Select')
0714 self.dbconn._executeRetry(self.rawconn, self.cursor, self.query)
0715
0716 def __iter__(self):
0717 return self
0718
0719 def next(self):
0720 result = self.cursor.fetchone()
0721 if result is None:
0722 self._cleanup()
0723 raise StopIteration
0724 if result[0] is None:
0725 return None
0726 if self.select.ops.get('lazyColumns', 0):
0727 obj = self.select.sourceClass.get(result[0], connection=self.dbconn)
0728 return obj
0729 else:
0730 obj = self.select.sourceClass.get(result[0], selectResults=result[1:], connection=self.dbconn)
0731 return obj
0732
0733 def _cleanup(self):
0734 if getattr(self, 'query', None) is None:
0735
0736 return
0737 self.query = None
0738 if not self.keepConnection:
0739 self.dbconn.releaseConnection(self.rawconn)
0740 self.dbconn = self.rawconn = self.select = self.cursor = None
0741
0742 def __del__(self):
0743 self._cleanup()
0744
class Transaction(object):
    """
    Connection-like object that runs everything on a single raw connection
    with autocommit switched off, until commit() or rollback() is called.

    Attributes not defined here are looked up on the parent connection
    (see __getattr__).
    """

    def __init__(self, dbConnection):
        # Start out obsolete so that __del__ does nothing if this
        # initializer fails part-way; flipped to False on the last line.
        self._obsolete = True
        self._dbConnection = dbConnection
        self._connection = dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0)
        self.cache = CacheSet(cache=dbConnection.doCache)
        # Maps class name -> list of ids deleted inside this transaction.
        self._deletedCache = {}
        self._obsolete = False

    def assertActive(self):
        """Assert that this transaction has not been made obsolete."""
        assert not self._obsolete, "This transaction has already gone through ROLLBACK; begin another transaction"

    def query(self, s):
        """Execute ``s`` on this transaction's raw connection."""
        self.assertActive()
        owner = self._dbConnection
        return owner._query(self._connection, s)

    def queryAll(self, s):
        """Execute ``s`` on this transaction's connection; return all rows."""
        self.assertActive()
        owner = self._dbConnection
        return owner._queryAll(self._connection, s)

    def queryOne(self, s):
        """Execute ``s`` on this transaction's connection; return one row."""
        self.assertActive()
        owner = self._dbConnection
        return owner._queryOne(self._connection, s)

    def queryInsertID(self, soInstance, id, names, values):
        """Run the parent's ``_queryInsertID`` on this transaction's connection."""
        self.assertActive()
        owner = self._dbConnection
        return owner._queryInsertID(
            self._connection, soInstance, id, names, values)

    def iterSelect(self, select):
        """
        Return an iterator over the select's results.

        The iteration is consumed into a list first, so all rows are
        fetched before the iterator is returned.
        """
        self.assertActive()
        iteration = select.IterationClass(self, self._connection,
                                          select, keepConnection=True)
        return iter(list(iteration))

    def _SO_delete(self, inst):
        """
        Record the deleted id under the instance's class name, then run
        the parent's ``_SO_delete`` with this transaction as ``self``.
        """
        className = inst.__class__.__name__
        self._deletedCache.setdefault(className, []).append(inst.id)
        rebound = new.instancemethod(
            self._dbConnection._SO_delete.im_func, self, self.__class__)
        return rebound(inst)

    def commit(self, close=False):
        """
        Commit the raw connection and expire the matching instances in the
        parent connection's cache.

        With ``close=True`` the transaction is made obsolete afterwards.
        Does nothing when the transaction is already obsolete.
        """
        if self._obsolete:
            return
        if self._dbConnection.debug:
            self._dbConnection.printDebug(self._connection, '', 'COMMIT')
        self._connection.commit()
        # Ids cached in this transaction, followed by ids deleted in it.
        touched = [(name, sub.allIDs())
                   for name, sub in self.cache.allSubCachesByClassNames().items()]
        touched.extend([(name, ids) for name, ids in self._deletedCache.items()])
        for cls, ids in touched:
            for id in ids:
                inst = self._dbConnection.cache.tryGetByName(id, cls)
                if inst is not None:
                    inst.expire()
        if close:
            self._makeObsolete()

    def rollback(self):
        """
        Roll the raw connection back, expire every instance cached in this
        transaction and make the transaction obsolete.

        Does nothing when the transaction is already obsolete.
        """
        if self._obsolete:
            return
        if self._dbConnection.debug:
            self._dbConnection.printDebug(self._connection, '', 'ROLLBACK')
        # Collect the ids before the rollback itself is issued.
        pending = [(sub, sub.allIDs()) for sub in self.cache.allSubCaches()]
        self._connection.rollback()

        for subCache, ids in pending:
            for id in ids:
                inst = subCache.tryGet(id)
                if inst is not None:
                    inst.expire()
        self._makeObsolete()

    def __getattr__(self, attr):
        """
        Fall back to the parent connection for unknown attributes.

        Methods are re-bound so that they run with this transaction as
        ``self``; ConnWrapper attributes are re-wrapped around this
        transaction; anything else is returned as is.
        """
        self.assertActive()
        target = getattr(self._dbConnection, attr)
        try:
            func = target.im_func
        except AttributeError:
            if isinstance(target, ConnWrapper):
                return ConnWrapper(target._soClass, self)
            return target
        return new.instancemethod(func, self, self.__class__)

    def _makeObsolete(self):
        """
        Mark the transaction obsolete and release its raw connection,
        switching autocommit back on first when the parent uses it.
        """
        self._obsolete = True
        owner = self._dbConnection
        if owner.autoCommit:
            owner._setAutoCommit(self._connection, 1)
        owner.releaseConnection(self._connection, explicit=True)
        self._connection = None
        self._deletedCache = {}

    def begin(self):
        """
        Start a new session on an obsolete transaction: fetch a raw
        connection and switch its autocommit off.
        """
        assert self._obsolete, "You cannot begin a new transaction session without rolling back this one"
        self._obsolete = False
        self._connection = self._dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0)

    def __del__(self):
        if not self._obsolete:
            self.rollback()

    def close(self):
        raise TypeError('You cannot just close transaction - you should either call rollback(), commit() or commit(close=True) to close the underlying connection.')
0871
class ConnectionHub(object):

    """
    A stand-in for a connection that resolves to the real one lazily.

    Assign a ConnectionHub as the ``_connection`` of a SQLObject subclass
    and bind an actual database connection afterwards, either for the
    whole process (``processConnection``) or for the current thread only
    (``threadConnection``).

    Keep a reference to the hub itself: it cannot be fetched back from
    the class or instance.

    Example::

        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub

        hub.threadConnection = connectionFromURI('...')

    """

    def __init__(self):
        self.threadingLocal = threading_local()

    def __get__(self, obj, type=None):
        # A connection stored directly on the instance takes precedence.
        if obj is not None:
            own = obj.__dict__
            if '_connection' in own:
                return own['_connection']
        return self.getConnection()

    def __set__(self, obj, value):
        obj.__dict__['_connection'] = value

    def getConnection(self):
        """
        Return the thread's connection, else the process connection.

        Raises AttributeError when neither has been set.
        """
        try:
            return self.threadingLocal.connection
        except AttributeError:
            pass
        try:
            return self.processConnection
        except AttributeError:
            raise AttributeError(
                "No connection has been defined for this thread "
                "or process")

    def doInTransaction(self, func, *args, **kw):
        """
        Call ``func(*args, **kw)`` inside a transaction and return its
        result.  The transaction is committed when the call returns and
        rolled back when it raises (the exception is re-raised).

        Example::

            sqlhub.doInTransaction(process_request, os.environ)

        While the call runs, the transaction replaces the thread (or
        process) connection; the previous connection is restored
        afterwards in either case.
        """
        # Remember which slot currently provides the connection.
        try:
            previous = self.threadingLocal.connection
        except AttributeError:
            previous = self.processConnection
            slot = 'processConnection'
        else:
            slot = 'threadConnection'
        txn = previous.transaction()
        setattr(self, slot, txn)
        try:
            try:
                result = func(*args, **kw)
            except:
                txn.rollback()
                raise
            txn.commit(close=True)
            return result
        finally:
            setattr(self, slot, previous)

    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value

    def _get_threadConnection(self):
        return self.threadingLocal.connection

    def _del_threadConnection(self):
        del self.threadingLocal.connection

    # Per-thread connection, stored on the thread-local object.
    threadConnection = property(fget=_get_threadConnection,
                                fset=_set_threadConnection,
                                fdel=_del_threadConnection)
0971
class ConnectionURIOpener(object):
    """
    Registry that turns connection URIs, or the names of registered
    connection instances, into connection objects.  Results are cached
    per URI string.
    """

    def __init__(self):
        # URI scheme -> builder callable (called to get the connection class)
        self.schemeBuilders = {}
        # instance name -> registered connection instance
        self.instanceNames = {}
        # full URI (or name) -> connection already handed out
        self.cachedURIs = {}

    def registerConnection(self, schemes, builder):
        """
        Register ``builder`` for every URI scheme in ``schemes``.

        Registering a different builder for an already known scheme fails
        an assertion; registering the same builder again is allowed.
        """
        for uriScheme in schemes:
            assert not uriScheme in self.schemeBuilders or self.schemeBuilders[uriScheme] is builder, "A driver has already been registered for the URI scheme %s" % uriScheme
            self.schemeBuilders[uriScheme] = builder

    def registerConnectionInstance(self, inst):
        """
        Register ``inst`` under ``inst.name``; unnamed instances are
        ignored.

        Registering the same instance again is allowed.  A different
        instance under an existing name, or a name containing ':', fails
        an assertion.
        """
        if inst.name:
            # Compare against ``inst`` itself.  The original referred to an
            # undefined name ``cls`` here, so re-registering a name raised
            # NameError instead of passing or asserting.
            assert not inst.name in self.instanceNames or self.instanceNames[inst.name] is inst, "A instance has already been registered with the name %s" % inst.name
            assert inst.name.find(':') == -1, "You cannot include ':' in your class names (%r)" % inst.name
            self.instanceNames[inst.name] = inst

    def connectionForURI(self, uri, oldUri=False, **args):
        """
        Return the connection for ``uri``.

        Extra keyword arguments are url-encoded and appended to the URI as
        query parameters.  A string without ':' is taken to be the name of
        a registered instance.  With ``oldUri`` true the old URI format is
        parsed.  The result is cached under the final URI string.
        """
        if args:
            if '?' not in uri:
                uri += '?' + urllib.urlencode(args)
            else:
                uri += '&' + urllib.urlencode(args)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if uri.find(':') != -1:
            scheme, rest = uri.split(':', 1)
            connCls = self.dbConnectionForScheme(scheme)
            if oldUri:
                conn = connCls.connectionFromOldURI(uri)
            else:
                conn = connCls.connectionFromURI(uri)
        else:
            # No scheme: look the string up as an instance name.
            assert uri in self.instanceNames, "No SQLObject driver exists under the name %s" % uri
            conn = self.instanceNames[uri]

        self.cachedURIs[uri] = conn
        return conn

    def dbConnectionForScheme(self, scheme):
        """
        Call the builder registered for ``scheme`` and return its result.

        An unknown scheme fails an assertion that lists the known ones.
        """
        assert scheme in self.schemeBuilders, (
            "No SQLObject driver exists for %s (only %s)"
            % (scheme, ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()
1023
# Single module-level opener instance.
TheURIOpener = ConnectionURIOpener()

# Module-level names for the bound methods of that instance.
registerConnection = TheURIOpener.registerConnection
registerConnectionInstance = TheURIOpener.registerConnectionInstance
connectionForURI = TheURIOpener.connectionForURI
dbConnectionForScheme = TheURIOpener.dbConnectionForScheme
1030
1031
1032import firebird
1033import maxdb
1034import mssql
1035import mysql
1036import postgres
1037import rdbhost
1038import sqlite
1039import sybase