0001import sqlbuilder
0002import dbconnection
0003import main
0004import joins
0005
class SelectResults(object):
    """
    A lazily evaluated SELECT over ``sourceClass``.

    An instance stores a WHERE clause (``self.clause``) and a dict of query
    options (``self.ops``).  Methods such as ``orderBy()``, ``distinct()``,
    ``filter()`` and slicing return a new SelectResults instead of
    modifying this one.  Rows are fetched through the connection only when
    the results are iterated, indexed, counted/accumulated or passed to
    ``getOne()``.
    """

    IterationClass = dbconnection.Iteration

    def __init__(self, sourceClass, clause, clauseTables=None,
                 **ops):
        """
        :param sourceClass: the SQLObject class being selected.
        :param clause: a ``sqlbuilder.SQLExpression``.  ``None`` or the
            string ``'all'`` select every row; any other non-expression
            value is wrapped in ``sqlbuilder.SQLConstant``.
        :param clauseTables: extra tables to add to the query's table list.
        :param ops: query options such as ``orderBy``, ``connection``,
            ``limit``, ``start``, ``end``, ``distinct``, ``reversed``,
            ``lazyColumns``, ``join`` and ``forUpdate``.
        """
        self.sourceClass = sourceClass
        if clause is None or (isinstance(clause, basestring)
                              and clause == 'all'):
            clause = sqlbuilder.SQLTrueClause
        if not isinstance(clause, sqlbuilder.SQLExpression):
            clause = sqlbuilder.SQLConstant(clause)
        self.clause = clause
        self.ops = ops

        # A missing orderBy (or an explicit NoDefault) means "use the
        # class's default ordering".
        if ops.get('orderBy', sqlbuilder.NoDefault) is sqlbuilder.NoDefault:
            ops['orderBy'] = sourceClass.sqlmeta.defaultOrder
        orderBy = ops['orderBy']
        if isinstance(orderBy, (tuple, list)):
            # A comprehension (rather than map()) always yields a real list.
            orderBy = [self._mungeOrderBy(item) for item in orderBy]
        else:
            orderBy = self._mungeOrderBy(orderBy)
        # 'orderBy' keeps the caller's spec so clone() can re-munge it;
        # 'dbOrderBy' is what queryForSelect() hands to sqlbuilder.Select.
        ops['dbOrderBy'] = orderBy

        # connection=None is treated the same as not passing a connection.
        if 'connection' in ops and ops['connection'] is None:
            del ops['connection']

        # limit=N is shorthand for start=0, end=N.
        if ops.get('limit', None):
            assert not ops.get('start', None) and not ops.get('end', None), \
                "'limit' cannot be used with 'start' or 'end'"
            ops["start"] = 0
            ops["end"] = ops.pop("limit")

        tablesSet = sqlbuilder.tablesUsedSet(self.clause,
                                             self._getConnection().dbName)
        if clauseTables:
            for table in clauseTables:
                tablesSet.add(table)
        self.clauseTables = clauseTables

        # The source table is always appended; it may duplicate an entry of
        # tablesSet.  NOTE(review): presumably sqlbuilder.Select tolerates
        # duplicates - confirm.
        self.tables = list(tablesSet) + [sourceClass.sqlmeta.table]

    def queryForSelect(self):
        """
        Build the ``sqlbuilder.Select`` for this result set.

        Selects the id column followed by every column of
        ``sqlmeta.columnList`` and applies the stored clause and options.
        """
        q = self.sourceClass.q
        columns = [q.id] + [getattr(q, col.name)
                            for col in self.sourceClass.sqlmeta.columnList]
        ops = self.ops
        query = sqlbuilder.Select(columns,
                                  where=self.clause,
                                  join=ops.get('join', sqlbuilder.NoDefault),
                                  distinct=ops.get('distinct', False),
                                  lazyColumns=ops.get('lazyColumns', False),
                                  start=ops.get('start', 0),
                                  end=ops.get('end', None),
                                  orderBy=ops.get('dbOrderBy',
                                                  sqlbuilder.NoDefault),
                                  reversed=ops.get('reversed', False),
                                  staticTables=self.tables,
                                  forUpdate=ops.get('forUpdate', False))
        return query

    def __repr__(self):
        return "<%s at %x>" % (self.__class__.__name__, id(self))

    def _getConnection(self):
        """Return the per-query connection if one was given, else the
        source class's connection."""
        return self.ops.get('connection') or self.sourceClass._connection

    def __str__(self):
        # Delegates to the connection's queryForSelect (presumably renders
        # the SQL text of this query).
        conn = self._getConnection()
        return conn.queryForSelect(self)

    def _mungeOrderBy(self, orderBy):
        """
        Convert one orderBy item into something sqlbuilder can render.

        A string with a leading ``-`` means descending order.  A string
        naming a key of ``sqlmeta.columns`` becomes the matching ``q``
        attribute; any other string becomes ``SQLConstant``.  Non-string
        values are returned unchanged.
        """
        # basestring (not str) so that unicode specs like u'-name' are
        # recognised as descending, matching the lookup below.
        if isinstance(orderBy, basestring) and orderBy.startswith('-'):
            orderBy = orderBy[1:]
            desc = True
        else:
            desc = False
        if not isinstance(orderBy, basestring):
            return orderBy
        columns = self.sourceClass.sqlmeta.columns
        if orderBy in columns:
            orderBy = getattr(self.sourceClass.q, columns[orderBy].name)
        else:
            orderBy = sqlbuilder.SQLConstant(orderBy)
        if desc:
            return sqlbuilder.DESC(orderBy)
        return orderBy

    def clone(self, **newOps):
        """Return a new SelectResults with the same clause and tables and
        this instance's options updated by ``newOps``."""
        ops = self.ops.copy()
        ops.update(newOps)
        return self.__class__(self.sourceClass, self.clause,
                              self.clauseTables, **ops)

    def orderBy(self, orderBy):
        return self.clone(orderBy=orderBy)

    def connection(self, conn):
        return self.clone(connection=conn)

    def limit(self, limit):
        return self[:limit]

    def lazyColumns(self, value):
        return self.clone(lazyColumns=value)

    def reversed(self):
        """Return a copy with the ``reversed`` option toggled."""
        return self.clone(reversed=not self.ops.get('reversed', False))

    def distinct(self):
        return self.clone(distinct=True)

    def newClause(self, new_clause):
        """Return a new SelectResults with ``new_clause`` and the same
        options and tables."""
        return self.__class__(self.sourceClass, new_clause,
                              self.clauseTables, **self.ops)

    def filter(self, filter_clause):
        """
        Return a new SelectResults whose clause is the current clause AND
        ``filter_clause``.  ``filter_clause=None`` returns ``self``.
        """
        if filter_clause is None:
            return self
        clause = self.clause
        # Defensive: parenthesise a raw string clause before AND-ing.
        if isinstance(clause, basestring):
            clause = sqlbuilder.SQLConstant('(%s)' % self.clause)
        return self.newClause(sqlbuilder.AND(clause, filter_clause))

    def __getitem__(self, value):
        """
        Slice or index the result set.

        Non-negative slices become start/end offsets (absolute, combined
        with any start/end already set) on a new SelectResults.  Slices
        with a negative bound, and negative integer indexes, fetch every
        row and index the resulting Python list.  A non-negative integer
        index fetches a single row (IndexError if there is none).
        """
        if isinstance(value, slice):
            assert not value.step, "Slices do not support steps"
            if not value.start and not value.stop:
                # Nothing to narrow ([:] and [0:]).
                return self

            if (value.start and value.start < 0) or \
               (value.stop and value.stop < 0):
                # Negative bounds can't be expressed as offsets; fall back
                # to slicing the fully fetched list.
                if value.start:
                    if value.stop:
                        return list(self)[value.start:value.stop]
                    return list(self)[value.start:]
                return list(self)[:value.stop]

            baseStart = self.ops.get('start', 0)
            prevEnd = self.ops.get('end', None)
            if value.start:
                assert value.start >= 0
                start = baseStart + value.start
                if value.stop is not None:
                    assert value.stop >= 0
                    if value.stop < value.start:
                        # An empty result.
                        end = start
                    else:
                        end = value.stop + baseStart
                        # Truncated by a previous slice.  (This used to read
                        # value['end'], which raises TypeError on a slice.)
                        if prevEnd is not None and prevEnd < end:
                            end = prevEnd
                else:
                    end = prevEnd
            else:
                start = baseStart
                end = value.stop + start
                # Truncated by a previous slice.
                if prevEnd is not None and prevEnd < end:
                    end = prevEnd
            if end is not None and end < start:
                # The new start lies beyond a previous slice's end, so the
                # result is empty; end < start would otherwise give an
                # inconsistent window (presumably a negative LIMIT).
                end = start
            return self.clone(start=start, end=end)
        else:
            if value < 0:
                return list(iter(self))[value]
            start = self.ops.get('start', 0) + value
            return list(self.clone(start=start, end=start + 1))[0]

    def __iter__(self):
        """Fetch every row up front (via lazyIter) and iterate over the
        resulting list; use lazyIter() to stream rows instead."""
        return iter(list(self.lazyIter()))

    def lazyIter(self):
        """
        Returns an iterator that will lazily pull rows out of the
        database and return SQLObject instances
        """
        conn = self._getConnection()
        return conn.iterSelect(self)

    def accumulate(self, *expressions):
        """
        Run ``expressions`` as a separate accumulate SELECT through the
        current connection and return its result.

        Any expression that is not a ``sqlbuilder.SQLExpression`` is
        wrapped in ``sqlbuilder.SQLConstant`` first.
        """
        conn = self._getConnection()
        exprs = []
        for expr in expressions:
            if not isinstance(expr, sqlbuilder.SQLExpression):
                expr = sqlbuilder.SQLConstant(expr)
            exprs.append(expr)
        return conn.accumulateSelect(self, *exprs)

    def count(self):
        """
        Return the number of rows matched by this result set.

        With ``distinct`` set, distinct values of the id column are counted
        (``COUNT(DISTINCT id)``); otherwise ``COUNT(*)`` is used.

        Raises AssertionError (when assertions are enabled) if start/end
        or limit are set.
        """
        assert not self.ops.get('start') and not self.ops.get('end'), \
            "start/end/limit have no meaning with 'count'"
        assert not (self.ops.get('distinct') and (self.ops.get('start')
                                                  or self.ops.get('end'))), \
            "distinct-counting of sliced objects is not supported"
        if self.ops.get('distinct'):
            idColumn = self._getConnection().sqlrepr(self.sourceClass.q.id)
            count = self.accumulate('COUNT(DISTINCT %s)' % idColumn)
        else:
            count = self.accumulate('COUNT(*)')
        # The assertions above reject start/end, so the adjustments below
        # only take effect when assertions are stripped (python -O).
        start = self.ops.get('start')
        if start:
            # Never report a negative count when start is past the last row.
            count = max(count - start, 0)
        if self.ops.get('end'):
            count = min(self.ops['end'] - self.ops.get('start', 0), count)
        return count

    def accumulateMany(self, *attributes):
        """ Making the expressions for count/sum/min/max/avg
        of a given select result attributes.
        `attributes` must be a list/tuple of pairs (func_name, attribute);
        `attribute` can be a column name (like 'a_column')
        or a dot-q attribute (like Table.q.aColumn)
        """
        conn = self._getConnection()
        if self.ops.get('distinct'):
            distinct = 'DISTINCT '
        else:
            distinct = ''
        expressions = []
        for func_name, attribute in attributes:
            # Column names (str or unicode) are used verbatim; anything else
            # (e.g. Table.q.aColumn) is rendered with the connection's
            # sqlrepr.  Checking only str would quote unicode names as
            # string literals.
            if not isinstance(attribute, basestring):
                attribute = conn.sqlrepr(attribute)
            expressions.append('%s(%s%s)' % (func_name, distinct, attribute))
        return self.accumulate(*expressions)

    def accumulateOne(self, func_name, attribute):
        """ Making the sum/min/max/avg of a given select result attribute.
        `attribute` can be a column name (like 'a_column')
        or a dot-q attribute (like Table.q.aColumn)
        """
        return self.accumulateMany((func_name, attribute))

    def sum(self, attribute):
        return self.accumulateOne("SUM", attribute)

    def min(self, attribute):
        return self.accumulateOne("MIN", attribute)

    def avg(self, attribute):
        return self.accumulateOne("AVG", attribute)

    def max(self, attribute):
        return self.accumulateOne("MAX", attribute)

    def getOne(self, default=sqlbuilder.NoDefault):
        """
        If a query is expected to only return a single value,
        using ``.getOne()`` will return just that value.

        If no results are found, ``SQLObjectNotFound`` will be
        raised, unless you pass in a default value (like
        ``.getOne(None)``).

        If more than one result is returned,
        ``SQLObjectIntegrityError`` will be raised.
        """
        results = list(self)
        if not results:
            if default is sqlbuilder.NoDefault:
                raise main.SQLObjectNotFound(
                    "No results matched the query for %s"
                    % self.sourceClass.__name__)
            return default
        if len(results) > 1:
            raise main.SQLObjectIntegrityError(
                "More than one result returned from query: %s"
                % results)
        return results[0]

    def throughTo(self):
        """
        Attribute access on the returned object (``results.throughTo.X``)
        delegates to ``self._throughTo('X')``.
        """
        class _throughTo_getter(object):
            def __init__(self, inst):
                self.sresult = inst

            def __getattr__(self, attr):
                return self.sresult._throughTo(attr)
        return _throughTo_getter(self)
    throughTo = property(throughTo)

    def _throughTo(self, attr):
        """
        Return a select on the class related through ``attr``, restricted
        to rows related to this result set.

        ``attr`` (with an ``ID`` suffix added if missing) is first looked
        up as a foreign-key column; otherwise it must be the joinMethodName
        of one of ``sqlmeta.joins``.  Raises AttributeError if neither
        matches.
        """
        otherClass = None
        orderBy = sqlbuilder.NoDefault

        if attr.endswith('ID'):
            fkName = attr
        else:
            fkName = attr + 'ID'
        ref = self.sourceClass.sqlmeta.columns.get(fkName, None)
        if ref and ref.foreignKey:
            otherClass, clause = self._throughToFK(ref)
        else:
            join = [x for x in self.sourceClass.sqlmeta.joins
                    if x.joinMethodName == attr]
            if join:
                join = join[0]
                orderBy = join.orderBy
                # Joins with an otherColumn go through an intermediate table.
                if hasattr(join, 'otherColumn'):
                    otherClass, clause = self._throughToRelatedJoin(join)
                else:
                    otherClass, clause = self._throughToMultipleJoin(join)

        if not otherClass:
            raise AttributeError("throughTo argument (got %s) should be name of foreignKey or SQL*Join in %s" % (attr, self.sourceClass))

        return otherClass.select(clause,
                                 orderBy=orderBy,
                                 connection=self._getConnection())

    def _throughToFK(self, col):
        """
        Match ``otherClass.q.id`` against an aliased DISTINCT subquery that
        selects foreign-key column ``col`` from this result set (with its
        ordering cleared via ``orderBy(None)``).
        """
        otherClass = getattr(self.sourceClass, "_SO_class_" + col.foreignKey)
        colName = col.name
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(getattr(self.sourceClass.q, colName),
                                 colName)]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__,
                                                   col.name))
        return otherClass, otherClass.q.id == getattr(query.q, colName)

    def _throughToMultipleJoin(self, join):
        """
        Match the other class's join column against an aliased DISTINCT
        subquery of this result set's ids.
        """
        otherClass = join.otherClass
        # Python attribute name of join.joinColumn, per the class's style.
        colName = join.soClass.sqlmeta.style.dbColumnToPythonAttr(
            join.joinColumn)
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]
        ).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__,
                                                   join.joinMethodName))
        joinColumn = getattr(otherClass.q, colName)
        return otherClass, joinColumn == query.q.id

    def _throughToRelatedJoin(self, join):
        """
        Link the other class to an aliased DISTINCT subquery of this
        result set's ids via the join's intermediate table.
        """
        otherClass = join.otherClass
        intTable = sqlbuilder.Table(join.intermediateTable)
        colName = join.joinColumn
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]
        ).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__,
                                                   join.joinMethodName))
        clause = sqlbuilder.AND(
            otherClass.q.id == getattr(intTable, join.otherColumn),
            getattr(intTable, colName) == query.q.id)
        return otherClass, clause
0345
# Public API exported by ``from <module> import *``.
__all__ = ["SelectResults"]