0001import dbconnection
0002import joins
0003import main
0004import sqlbuilder
0005
0006__all__ = ['SelectResults']
0007
class SelectResults(object):
    """
    A lazily evaluated SELECT over the table of ``sourceClass``.

    Instances are treated as immutable: every modifier (``orderBy``,
    ``limit``, ``distinct``, ``filter``, slicing, ...) returns a new
    instance built through ``clone()`` or ``newClause()``.  The query is
    only sent to the database when the results are iterated (through
    ``conn.iterSelect``) or accumulated (through ``conn.accumulateSelect``).
    """

    IterationClass = dbconnection.Iteration

    def __init__(self, sourceClass, clause, clauseTables=None,
                 **ops):
        """
        :param sourceClass: the SQLObject class whose rows are selected.
        :param clause: the WHERE clause.  ``None`` or the string ``'all'``
            selects every row (``sqlbuilder.SQLTrueClause``); anything that
            is not already an ``SQLExpression`` is wrapped in
            ``sqlbuilder.SQLConstant``.
        :param clauseTables: extra tables added to the table list that is
            handed to ``sqlbuilder.Select`` as ``staticTables``.
        :param ops: query options, e.g. ``orderBy``, ``connection``,
            ``limit``, ``start``, ``end``, ``join``, ``distinct``,
            ``lazyColumns``, ``reversed``, ``forUpdate``.  ``limit`` is
            rewritten into ``start=0, end=limit``; the translated ordering
            is stored under ``dbOrderBy``.
        """
        self.sourceClass = sourceClass
        # basestring (not str) so that a unicode u'all' is recognised too.
        if clause is None or (isinstance(clause, basestring)
                              and clause == 'all'):
            clause = sqlbuilder.SQLTrueClause
        if not isinstance(clause, sqlbuilder.SQLExpression):
            clause = sqlbuilder.SQLConstant(clause)
        self.clause = clause
        self.ops = ops
        if ops.get('orderBy', sqlbuilder.NoDefault) is sqlbuilder.NoDefault:
            ops['orderBy'] = sourceClass.sqlmeta.defaultOrder
        orderBy = ops['orderBy']
        if isinstance(orderBy, (tuple, list)):
            # A list comprehension (rather than map()) guarantees a real,
            # re-iterable list regardless of Python version.
            orderBy = [self._mungeOrderBy(item) for item in orderBy]
        else:
            orderBy = self._mungeOrderBy(orderBy)
        ops['dbOrderBy'] = orderBy
        # connection=None means "use the class default"; drop it so that
        # _getConnection() falls back to sourceClass._connection.
        if 'connection' in ops and ops['connection'] is None:
            del ops['connection']
        if ops.get('limit', None):
            assert not ops.get('start', None) and not ops.get('end', None), \
                "'limit' cannot be used with 'start' or 'end'"
            ops["start"] = 0
            ops["end"] = ops.pop("limit")

        tablesSet = sqlbuilder.tablesUsedSet(self.clause,
                                             self._getConnection().dbName)
        if clauseTables:
            for table in clauseTables:
                tablesSet.add(table)
        self.clauseTables = clauseTables
        # Explicitly post-adding-in sqlmeta.table, sqlbuilder.Select will
        # handle sqlrepr'ing and dupes
        self.tables = list(tablesSet) + [sourceClass.sqlmeta.table]

    def queryForSelect(self):
        """
        Build the ``sqlbuilder.Select`` for this result set.

        The selected columns are ``sourceClass.q.id`` followed by every
        column in ``sqlmeta.columnList``; WHERE, join, ordering, slicing
        and the remaining flags come from ``self.clause`` and ``self.ops``.
        """
        q = self.sourceClass.q
        columns = [q.id] + [getattr(q, col.name)
                            for col in self.sourceClass.sqlmeta.columnList]
        ops = self.ops
        query = sqlbuilder.Select(columns,
                                  where=self.clause,
                                  join=ops.get('join', sqlbuilder.NoDefault),
                                  distinct=ops.get('distinct', False),
                                  lazyColumns=ops.get('lazyColumns', False),
                                  start=ops.get('start', 0),
                                  end=ops.get('end', None),
                                  orderBy=ops.get('dbOrderBy',
                                                  sqlbuilder.NoDefault),
                                  reversed=ops.get('reversed', False),
                                  staticTables=self.tables,
                                  forUpdate=ops.get('forUpdate', False))
        return query

    def __repr__(self):
        return "<%s at %x>" % (self.__class__.__name__, id(self))

    def _getConnection(self):
        """Return the ``connection`` op if set, else the class connection."""
        return self.ops.get('connection') or self.sourceClass._connection

    def __str__(self):
        """Return the SQL text the connection generates for this query."""
        conn = self._getConnection()
        return conn.queryForSelect(self)

    def _mungeOrderBy(self, orderBy):
        """
        Translate one ``orderBy`` item into something sqlbuilder can render.

        A string may start with ``'-'`` to request descending order.  A
        string naming a key of ``sqlmeta.columns`` becomes that column's
        ``.q`` attribute; any other string becomes an ``SQLConstant``.
        Non-string values (including ``None``) are returned unchanged.
        """
        if not isinstance(orderBy, basestring):
            return orderBy
        desc = orderBy.startswith('-')
        if desc:
            orderBy = orderBy[1:]
        columns = self.sourceClass.sqlmeta.columns
        if orderBy in columns:
            val = getattr(self.sourceClass.q, columns[orderBy].name)
        else:
            val = sqlbuilder.SQLConstant(orderBy)
        if desc:
            return sqlbuilder.DESC(val)
        return val

    def clone(self, **newOps):
        """Return a new instance with ``self.ops`` updated by ``newOps``."""
        ops = self.ops.copy()
        ops.update(newOps)
        return self.__class__(self.sourceClass, self.clause,
                              self.clauseTables, **ops)

    def orderBy(self, orderBy):
        """Return a copy ordered by ``orderBy`` (see ``_mungeOrderBy``)."""
        return self.clone(orderBy=orderBy)

    def connection(self, conn):
        """Return a copy that runs its queries through ``conn``."""
        return self.clone(connection=conn)

    def limit(self, limit):
        """Return at most ``limit`` rows; same as ``self[:limit]``."""
        return self[:limit]

    def lazyColumns(self, value):
        """Return a copy with the ``lazyColumns`` option set to ``value``."""
        return self.clone(lazyColumns=value)

    def reversed(self):
        """Return a copy with the ``reversed`` flag toggled."""
        return self.clone(reversed=not self.ops.get('reversed', False))

    def distinct(self):
        """Return a copy that selects DISTINCT rows."""
        return self.clone(distinct=True)

    def newClause(self, new_clause):
        """Return a copy with the WHERE clause replaced by ``new_clause``."""
        return self.__class__(self.sourceClass, new_clause,
                              self.clauseTables, **self.ops)

    def filter(self, filter_clause):
        """
        Return a copy whose WHERE clause is ``AND(self.clause,
        filter_clause)``.  ``None`` is a no-op and returns ``self``.
        """
        if filter_clause is None:
            # None doesn't filter anything, it's just a no-op:
            return self
        clause = self.clause
        # NOTE(review): __init__ always wraps string clauses in an
        # SQLExpression, so this branch looks unreachable; confirm before
        # relying on it to parenthesize raw SQL strings.
        if isinstance(clause, basestring):
            clause = sqlbuilder.SQLConstant('(%s)' % clause)
        return self.newClause(sqlbuilder.AND(clause, filter_clause))

    def __getitem__(self, value):
        """
        ``results[i]`` returns a single row; ``results[a:b]`` returns a new
        instance whose start/end are offset by any earlier slice and
        truncated to that slice's end.

        Negative indexes or bounds are not pushed into SQL: the whole
        result is fetched into a list and indexed in Python (so negative
        slices return a plain list).  Slice steps are not supported.

        :raises IndexError: for an index outside the results.
        """
        if isinstance(value, slice):
            assert not value.step, "Slices do not support steps"
            if not value.start and not value.stop:
                # No need to copy, I'm immutable
                # NOTE(review): this also catches results[:0], which thus
                # returns every row instead of none.
                return self

            # Negative indexes aren't handled (and everything we
            # don't handle ourselves we just create a list to
            # handle)
            if (value.start and value.start < 0) \
                    or (value.stop and value.stop < 0):
                return list(self)[value.start:value.stop]

            offset = self.ops.get('start', 0)
            prevEnd = self.ops.get('end', None)
            lo = value.start or 0
            start = offset + lo
            if value.stop is None:
                end = prevEnd
            elif value.stop < lo:
                # an empty result:
                end = start
            else:
                end = offset + value.stop
                if prevEnd is not None and prevEnd < end:
                    # truncated by previous slice:
                    end = prevEnd
            if end is not None and start > end:
                # The slice starts past the end of a previous slice; clamp
                # so the result is empty rather than having start > end.
                start = end
            return self.clone(start=start, end=end)

        if value < 0:
            return list(iter(self))[value]
        start = self.ops.get('start', 0) + value
        prevEnd = self.ops.get('end', None)
        if prevEnd is not None and start >= prevEnd:
            # Past the end of a previous slice; without this check the
            # query would return a row lying outside that slice.
            raise IndexError("SelectResults index out of range")
        return list(self.clone(start=start, end=start + 1))[0]

    def __iter__(self):
        # @@: This could be optimized, using a simpler algorithm
        # since we don't have to worry about garbage collection,
        # etc., like we do with .lazyIter()
        return iter(list(self.lazyIter()))

    def lazyIter(self):
        """
        Returns an iterator that will lazily pull rows out of the
        database and return SQLObject instances
        """
        conn = self._getConnection()
        return conn.iterSelect(self)

    def accumulate(self, *expressions):
        """
        Run ``expressions`` (e.g. ``'COUNT(*)'``) as an accumulating SELECT
        over this result set through the current connection and return
        what ``conn.accumulateSelect`` returns.  Non-SQLExpression values
        are wrapped in ``sqlbuilder.SQLConstant``.
        """
        conn = self._getConnection()
        exprs = []
        for expr in expressions:
            if not isinstance(expr, sqlbuilder.SQLExpression):
                expr = sqlbuilder.SQLConstant(expr)
            exprs.append(expr)
        return conn.accumulateSelect(self, *exprs)

    def count(self):
        """
        Return the number of matching rows: ``COUNT(*)``, or
        ``COUNT(DISTINCT id)`` when ``distinct()`` is in effect.

        Counting sliced results (start/end/limit) is rejected by the
        assertions below; the start/end adjustment afterwards only runs
        when assertions are disabled (``python -O``).
        """
        assert not self.ops.get('start') and not self.ops.get('end'), \
            "start/end/limit have no meaning with 'count'"
        assert not (self.ops.get('distinct') and (self.ops.get('start')
                                                  or self.ops.get('end'))), \
            "distinct-counting of sliced objects is not supported"
        if self.ops.get('distinct'):
            # Column must be specified, so we are using unique ID column.
            # COUNT(DISTINCT column) is supported by MySQL and PostgreSQL,
            # but not by SQLite. Perhaps more portable would be subquery:
            #  SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table)
            count = self.accumulate('COUNT(DISTINCT %s)' %
                                    self._getConnection().sqlrepr(
                                        self.sourceClass.q.id))
        else:
            count = self.accumulate('COUNT(*)')
        # `or 0` tolerates an explicit start=None in ops.
        start = self.ops.get('start') or 0
        if start:
            # Never report a negative count when start is past the end.
            count = max(count - start, 0)
        end = self.ops.get('end')
        if end:
            count = min(end - start, count)
        return count

    def accumulateMany(self, *attributes):
        """ Making the expressions for count/sum/min/max/avg
            of a given select result attributes.
            `attributes` must be a list/tuple of pairs (func_name, attribute);
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
            When distinct() is in effect each expression is rendered as
            FUNC(DISTINCT attribute).
        """
        expressions = []
        conn = self._getConnection()
        if self.ops.get('distinct'):
            distinct = 'DISTINCT '
        else:
            distinct = ''
        for func_name, attribute in attributes:
            # basestring (not str): a unicode column name must be used as
            # a name, not sqlrepr'd into a quoted string literal.
            if not isinstance(attribute, basestring):
                attribute = conn.sqlrepr(attribute)
            expression = '%s(%s%s)' % (func_name, distinct, attribute)
            expressions.append(expression)
        return self.accumulate(*expressions)

    def accumulateOne(self, func_name, attribute):
        """ Making the sum/min/max/avg of a given select result attribute.
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        return self.accumulateMany((func_name, attribute))

    def sum(self, attribute):
        """Return SUM(attribute) over this result set."""
        return self.accumulateOne("SUM", attribute)

    def min(self, attribute):
        """Return MIN(attribute) over this result set."""
        return self.accumulateOne("MIN", attribute)

    def avg(self, attribute):
        """Return AVG(attribute) over this result set."""
        return self.accumulateOne("AVG", attribute)

    def max(self, attribute):
        """Return MAX(attribute) over this result set."""
        return self.accumulateOne("MAX", attribute)

    def getOne(self, default=sqlbuilder.NoDefault):
        """
        If a query is expected to only return a single value,
        using ``.getOne()`` will return just that value.

        If no results are found, ``SQLObjectNotFound`` will be
        raised, unless you pass in a default value (like
        ``.getOne(None)``).

        If more than one result is returned,
        ``SQLObjectIntegrityError`` will be raised.
        """
        results = list(self)
        if not results:
            if default is sqlbuilder.NoDefault:
                raise main.SQLObjectNotFound(
                    "No results matched the query for %s"
                    % self.sourceClass.__name__)
            return default
        if len(results) > 1:
            raise main.SQLObjectIntegrityError(
                "More than one result returned from query: %s"
                % results)
        return results[0]

    def throughTo(self):
        """
        Attribute-style access to related result sets:
        ``results.throughTo.name`` is ``results._throughTo('name')``.
        """
        class _throughTo_getter(object):
            def __init__(self, inst):
                self.sresult = inst
            def __getattr__(self, attr):
                return self.sresult._throughTo(attr)
        return _throughTo_getter(self)
    throughTo = property(throughTo)

    def _throughTo(self, attr):
        """
        Return ``otherClass.select(...)`` for the rows related to this
        result set through ``attr``.

        ``attr`` is either a foreignKey column name (with or without the
        trailing ``'ID'``) or the ``joinMethodName`` of an entry in
        ``sqlmeta.joins``.  Joins that have an ``otherColumn`` attribute
        are handled as related joins, the others as multiple joins.

        :raises AttributeError: if ``attr`` matches neither.
        """
        otherClass = None
        clause = None
        orderBy = sqlbuilder.NoDefault

        if attr.endswith('ID'):
            fkName = attr
        else:
            fkName = attr + 'ID'
        ref = self.sourceClass.sqlmeta.columns.get(fkName, None)
        if ref and ref.foreignKey:
            otherClass, clause = self._throughToFK(ref)
        else:
            matches = [x for x in self.sourceClass.sqlmeta.joins
                       if x.joinMethodName == attr]
            if matches:
                join = matches[0]
                orderBy = join.orderBy
                if hasattr(join, 'otherColumn'):
                    otherClass, clause = self._throughToRelatedJoin(join)
                else:
                    otherClass, clause = self._throughToMultipleJoin(join)

        if not otherClass:
            raise AttributeError("throughTo argument (got %s) should be name of foreignKey or SQL*Join in %s" % (attr, self.sourceClass))

        return otherClass.select(clause,
                                 orderBy=orderBy,
                                 connection=self._getConnection())

    def _throughToFK(self, col):
        """
        Return ``(otherClass, clause)`` for foreignKey column ``col``:
        ``clause`` matches ``otherClass.q.id`` against the DISTINCT values
        of ``col`` in this result set, selected through an aliased subquery.
        """
        otherClass = getattr(self.sourceClass, "_SO_class_" + col.foreignKey)
        colName = col.name
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(getattr(self.sourceClass.q, colName),
                                 colName)]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__,
                                                   col.name))
        return otherClass, otherClass.q.id == getattr(query.q, colName)

    def _throughToMultipleJoin(self, join):
        """
        Return ``(otherClass, clause)`` for a multiple join: ``clause``
        matches the join column of ``join.otherClass`` against the DISTINCT
        ids of this result set, selected through an aliased subquery.
        """
        otherClass = join.otherClass
        colName = join.soClass.sqlmeta.style.dbColumnToPythonAttr(
            join.joinColumn)
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]
        ).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__,
                                                   join.joinMethodName))
        joinColumn = getattr(otherClass.q, colName)
        return otherClass, joinColumn == query.q.id

    def _throughToRelatedJoin(self, join):
        """
        Return ``(otherClass, clause)`` for a related join: ``clause`` links
        ``join.otherClass`` ids to the DISTINCT ids of this result set
        through the join's intermediate table.
        """
        otherClass = join.otherClass
        intTable = sqlbuilder.Table(join.intermediateTable)
        colName = join.joinColumn
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]
        ).orderBy(None).distinct()
        query = sqlbuilder.Alias(query, "%s_%s" % (self.sourceClass.__name__,
                                                   join.joinMethodName))
        clause = sqlbuilder.AND(
            otherClass.q.id == getattr(intTable, join.otherColumn),
            getattr(intTable, colName) == query.q.id)
        return otherClass, clause