0001import sqlbuilder
0002import dbconnection
0003import main
0004import joins
0005
class SelectResults(object):
    """
    Immutable description of a SELECT over ``sourceClass``.

    An instance holds a WHERE clause (``self.clause``) plus keyword
    options (``self.ops``: ``orderBy``, ``start``/``end``/``limit``,
    ``distinct``, ``reversed``, ``lazyColumns``, ``join``, ``forUpdate``,
    ``connection``).  Refining methods (``orderBy``, ``filter``, slicing,
    ``distinct`` ...) return new instances via ``clone``/``newClause``
    rather than mutating this one.  Rows are fetched through the
    connection's ``iterSelect`` when the object is iterated.
    """
    IterationClass = dbconnection.Iteration

    def __init__(self, sourceClass, clause, clauseTables=None,
                 **ops):
        """
        :param sourceClass: the SQLObject class being selected.
        :param clause: a ``sqlbuilder.SQLExpression``, a raw SQL string,
            or ``None`` / ``'all'`` for no restriction.
        :param clauseTables: extra table names to include in FROM.
        :param ops: query options; see the class docstring.
        """
        self.sourceClass = sourceClass
        # None or the literal string 'all' means "no WHERE restriction".
        if clause is None or (isinstance(clause, basestring)
                              and clause == 'all'):
            clause = sqlbuilder.SQLTrueClause
        # Anything that is not already an expression (e.g. raw SQL text)
        # is wrapped verbatim.
        if not isinstance(clause, sqlbuilder.SQLExpression):
            clause = sqlbuilder.SQLConstant(clause)
        self.clause = clause
        self.ops = ops
        if ops.get('orderBy', sqlbuilder.NoDefault) is sqlbuilder.NoDefault:
            ops['orderBy'] = sourceClass.sqlmeta.defaultOrder
        orderBy = ops['orderBy']
        if isinstance(orderBy, (tuple, list)):
            orderBy = [self._mungeOrderBy(item) for item in orderBy]
        else:
            orderBy = self._mungeOrderBy(orderBy)
        # 'orderBy' keeps what the caller passed (so clone() can re-derive
        # it); 'dbOrderBy' is the form handed to sqlbuilder.Select.
        ops['dbOrderBy'] = orderBy
        if 'connection' in ops and ops['connection'] is None:
            del ops['connection']
        if ops.get('limit', None):
            assert not ops.get('start', None) and not ops.get('end', None), \
                "'limit' cannot be used with 'start' or 'end'"
            # 'limit' is sugar for a [0:limit] window.
            ops["start"] = 0
            ops["end"] = ops.pop("limit")

        tablesSet = sqlbuilder.tablesUsedSet(self.clause,
                                             self._getConnection().dbName)
        if clauseTables:
            for table in clauseTables:
                tablesSet.add(table)
        self.clauseTables = clauseTables
        # Explicitly post-adding-in sqlmeta.table, sqlbuilder.Select will
        # handle sqlrepr'ing and dupes
        self.tables = list(tablesSet) + [sourceClass.sqlmeta.table]

    def queryForSelect(self):
        """
        Build the ``sqlbuilder.Select`` for this result set.

        Selects the id column followed by every column in
        ``sqlmeta.columnList``, applying the clause and the options
        stored in ``self.ops``.
        """
        columns = [self.sourceClass.q.id] + [
            getattr(self.sourceClass.q, x.name)
            for x in self.sourceClass.sqlmeta.columnList]
        query = sqlbuilder.Select(columns,
                                  where=self.clause,
                                  join=self.ops.get('join', sqlbuilder.NoDefault),
                                  distinct=self.ops.get('distinct', False),
                                  lazyColumns=self.ops.get('lazyColumns', False),
                                  start=self.ops.get('start', 0),
                                  end=self.ops.get('end', None),
                                  orderBy=self.ops.get('dbOrderBy', sqlbuilder.NoDefault),
                                  reversed=self.ops.get('reversed', False),
                                  staticTables=self.tables,
                                  forUpdate=self.ops.get('forUpdate', False))
        return query

    def __repr__(self):
        return "<%s at %x>" % (self.__class__.__name__, id(self))

    def _getConnection(self):
        """Return the explicit ``connection`` op, else the class's own."""
        return self.ops.get('connection') or self.sourceClass._connection

    def __str__(self):
        """Return the SQL text the connection would run for this select."""
        conn = self._getConnection()
        return conn.queryForSelect(self)

    def _mungeOrderBy(self, orderBy):
        """
        Translate one ``orderBy`` item into a sqlbuilder expression.

        A leading ``'-'`` on a string requests descending order.  A string
        naming a column in ``sqlmeta.columns`` becomes that column's
        ``q`` attribute; any other string is used as raw SQL.  Non-string
        values are returned unchanged.
        """
        if isinstance(orderBy, basestring) and orderBy.startswith('-'):
            orderBy = orderBy[1:]
            desc = True
        else:
            desc = False
        if not isinstance(orderBy, basestring):
            return orderBy
        columns = self.sourceClass.sqlmeta.columns
        if orderBy in columns:
            val = getattr(self.sourceClass.q, columns[orderBy].name)
        else:
            val = sqlbuilder.SQLConstant(orderBy)
        if desc:
            return sqlbuilder.DESC(val)
        return val

    def clone(self, **newOps):
        """Return a new SelectResults with ``newOps`` merged over our ops."""
        ops = self.ops.copy()
        ops.update(newOps)
        return self.__class__(self.sourceClass, self.clause,
                              self.clauseTables, **ops)

    def orderBy(self, orderBy):
        return self.clone(orderBy=orderBy)

    def connection(self, conn):
        return self.clone(connection=conn)

    def limit(self, limit):
        return self[:limit]

    def lazyColumns(self, value):
        return self.clone(lazyColumns=value)

    def reversed(self):
        return self.clone(reversed=not self.ops.get('reversed', False))

    def distinct(self):
        return self.clone(distinct=True)

    def newClause(self, new_clause):
        """Return a copy with the same ops but a different WHERE clause."""
        return self.__class__(self.sourceClass, new_clause,
                              self.clauseTables, **self.ops)

    def filter(self, filter_clause):
        """
        Return a copy whose clause is ``self.clause AND filter_clause``.
        ``None`` is a no-op and returns ``self``.
        """
        if filter_clause is None:
            # None doesn't filter anything, it's just a no-op:
            return self
        clause = self.clause
        # NOTE(review): __init__ always wraps string clauses in
        # SQLConstant, so this branch looks unreachable and a raw
        # "a OR b" clause is ANDed without parentheses -- confirm how
        # sqlbuilder.AND renders SQLConstant operands.
        if isinstance(clause, basestring):
            clause = sqlbuilder.SQLConstant('(%s)' % self.clause)
        return self.newClause(sqlbuilder.AND(clause, filter_clause))

    def __getitem__(self, value):
        """
        Slice or index the result set.

        Non-negative slices return a new SelectResults whose ``start`` and
        ``end`` are relative to any previous slice.  Negative indexes
        cannot be expressed as OFFSET/LIMIT, so those fall back to
        fetching every row and slicing a Python list.  An integer index
        returns a single row, raising IndexError when it falls outside
        the result set.
        """
        if isinstance(value, slice):
            assert not value.step, "Slices do not support steps"
            if not value.start and not value.stop:
                # No need to copy, I'm immutable
                return self

            # Negative indexes aren't handled (and everything we
            # don't handle ourselves we just create a list to
            # handle)
            if (value.start and value.start < 0) \
                    or (value.stop and value.stop < 0):
                return list(self)[value.start:value.stop]

            offset = self.ops.get('start', 0)
            prevEnd = self.ops.get('end', None)
            if value.start:
                assert value.start >= 0
                start = offset + value.start
                if value.stop is not None:
                    assert value.stop >= 0
                    if value.stop < value.start:
                        # an empty result:
                        end = start
                    else:
                        end = value.stop + offset
                        if prevEnd is not None and prevEnd < end:
                            # truncated by previous slice:
                            end = prevEnd
                else:
                    end = prevEnd
            else:
                start = offset
                end = value.stop + start
                if prevEnd is not None and prevEnd < end:
                    end = prevEnd
            if end is not None and end < start:
                # The window starts past a previous slice's end; clamp to
                # an empty window instead of a negative row count.
                end = start
            return self.clone(start=start, end=end)
        else:
            if value < 0:
                return list(iter(self))[value]
            start = self.ops.get('start', 0) + value
            prevEnd = self.ops.get('end', None)
            if prevEnd is not None and start >= prevEnd:
                # Beyond a previous slice: don't fetch a row from outside it.
                raise IndexError("SelectResults index out of range")
            return list(self.clone(start=start, end=start + 1))[0]

    def __iter__(self):
        # @@: This could be optimized, using a simpler algorithm
        # since we don't have to worry about garbage collection,
        # etc., like we do with .lazyIter()
        return iter(list(self.lazyIter()))

    def lazyIter(self):
        """
        Returns an iterator that will lazily pull rows out of the
        database and return SQLObject instances
        """
        conn = self._getConnection()
        return conn.iterSelect(self)

    def accumulate(self, *expressions):
        """ Use accumulate expression(s) to select result
            using another SQL select through current
            connection.
            Non-expression arguments are wrapped as raw SQL
            (``SQLConstant``).
            Return the accumulate result
        """
        conn = self._getConnection()
        exprs = []
        for expr in expressions:
            if not isinstance(expr, sqlbuilder.SQLExpression):
                expr = sqlbuilder.SQLConstant(expr)
            exprs.append(expr)
        return conn.accumulateSelect(self, *exprs)

    def count(self):
        """ Counting elements of current select results """
        assert not self.ops.get('start') and not self.ops.get('end'), \
            "start/end/limit have no meaning with 'count'"
        assert not (self.ops.get('distinct') and (self.ops.get('start')
                                                  or self.ops.get('end'))), \
            "distinct-counting of sliced objects is not supported"
        if self.ops.get('distinct'):
            # Column must be specified, so we are using unique ID column.
            # COUNT(DISTINCT column) is supported by MySQL and PostgreSQL,
            # but not by SQLite. Perhaps more portable would be subquery:
            #  SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table)
            count = self.accumulate('COUNT(DISTINCT %s)' % self._getConnection().sqlrepr(self.sourceClass.q.id))
        else:
            count = self.accumulate('COUNT(*)')
        # Only reachable with start/end set when asserts are stripped
        # (python -O); then adjust the total to the sliced window.
        if self.ops.get('start'):
            count -= self.ops['start']
        if self.ops.get('end'):
            count = min(self.ops['end'] - self.ops.get('start', 0), count)
        return count

    def accumulateMany(self, *attributes):
        """ Making the expressions for count/sum/min/max/avg
            of a given select result attributes.
            `attributes` must be a list/tuple of pairs (func_name, attribute);
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        conn = self._getConnection()
        if self.ops.get('distinct'):
            distinct = 'DISTINCT '
        else:
            distinct = ''
        expressions = []
        for func_name, attribute in attributes:
            # Column names (str or unicode) are used verbatim; anything
            # else is rendered through the connection.
            if not isinstance(attribute, basestring):
                attribute = conn.sqlrepr(attribute)
            expressions.append('%s(%s%s)' % (func_name, distinct, attribute))
        return self.accumulate(*expressions)

    def accumulateOne(self, func_name, attribute):
        """ Making the sum/min/max/avg of a given select result attribute.
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        return self.accumulateMany((func_name, attribute))

    def sum(self, attribute):
        return self.accumulateOne("SUM", attribute)

    def min(self, attribute):
        return self.accumulateOne("MIN", attribute)

    def avg(self, attribute):
        return self.accumulateOne("AVG", attribute)

    def max(self, attribute):
        return self.accumulateOne("MAX", attribute)

    def getOne(self, default=sqlbuilder.NoDefault):
        """
        If a query is expected to only return a single value,
        using ``.getOne()`` will return just that value.

        If no results are found, ``SQLObjectNotFound`` will be
        raised, unless you pass in a default value (like
        ``.getOne(None)``).

        If more than one result is returned,
        ``SQLObjectIntegrityError`` will be raised.
        """
        results = list(self)
        if not results:
            if default is sqlbuilder.NoDefault:
                raise main.SQLObjectNotFound(
                    "No results matched the query for %s"
                    % self.sourceClass.__name__)
            return default
        if len(results) > 1:
            raise main.SQLObjectIntegrityError(
                "More than one result returned from query: %s"
                % results)
        return results[0]

    def throughTo(self):
        """
        Attribute-style access to related rows: ``rs.throughTo.attr``
        delegates to ``self._throughTo('attr')``.
        """
        class _throughTo_getter(object):
            def __init__(self, inst):
                self.sresult = inst
            def __getattr__(self, attr):
                return self.sresult._throughTo(attr)
        return _throughTo_getter(self)
    throughTo = property(throughTo)

    def _throughTo(self, attr):
        """
        Select the rows of the class related through ``attr``.

        ``attr`` is resolved first as a foreign-key column (``attr`` or
        ``attr + 'ID'``), then as a join's ``joinMethodName``.

        :raises AttributeError: if ``attr`` matches neither.
        """
        otherClass = None
        orderBy = sqlbuilder.NoDefault

        if attr.endswith('ID'):
            colKey = attr
        else:
            colKey = attr + 'ID'
        ref = self.sourceClass.sqlmeta.columns.get(colKey, None)
        if ref and ref.foreignKey:
            otherClass, clause = self._throughToFK(ref)
        else:
            join = [x for x in self.sourceClass.sqlmeta.joins if x.joinMethodName==attr]
            if join:
                join = join[0]
                orderBy = join.orderBy
                # Joins exposing 'otherColumn' go through an intermediate
                # table; the rest are handled as MultipleJoin-style.
                if hasattr(join, 'otherColumn'):
                    otherClass, clause = self._throughToRelatedJoin(join)
                else:
                    otherClass, clause = self._throughToMultipleJoin(join)

        if not otherClass:
            raise AttributeError("throughTo argument (got %s) should be name of foreignKey or SQL*Join in %s" % (attr, self.sourceClass))

        return otherClass.select(clause,
                                 orderBy=orderBy,
                                 connection=self._getConnection())

    def _throughToFK(self, col):
        """Return (otherClass, clause) matching otherClass ids to the FK values of this select."""
        otherClass = getattr(self.sourceClass, "_SO_class_"+col.foreignKey)
        colName = col.name
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(getattr(self.sourceClass.q, colName), colName)]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, col.name))
        return otherClass, otherClass.q.id==getattr(query.q, colName)

    def _throughToMultipleJoin(self, join):
        """Return (otherClass, clause) matching otherClass's join column to the ids of this select."""
        otherClass = join.otherClass
        colName = join.soClass.sqlmeta.style.dbColumnToPythonAttr(join.joinColumn)
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
        joinColumn = getattr(otherClass.q, colName)
        return otherClass, joinColumn==query.q.id

    def _throughToRelatedJoin(self, join):
        """Return (otherClass, clause) linking otherClass to the ids of this select via the intermediate table."""
        otherClass = join.otherClass
        intTable = sqlbuilder.Table(join.intermediateTable)
        colName = join.joinColumn
        query = self.queryForSelect().newItems([sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__, join.joinMethodName))
        clause = sqlbuilder.AND(otherClass.q.id == getattr(intTable, join.otherColumn),
                                getattr(intTable, colName) == query.q.id)
        return otherClass, clause
# Public API of this module: only the result-set class is exported.
__all__ = ["SelectResults"]