0001from . import dbconnection
0002from . import sqlbuilder
0003from .compat import string_type
0004
0005
0006__all__ = ['SelectResults']
0007
0008
class SelectResults(object):
    """Lazy description of a SELECT over ``sourceClass``.

    An instance holds the WHERE clause plus a dict of query options
    (``self.ops``).  The query is only run when the results are iterated
    (via the connection's ``iterSelect``) or aggregated (via
    ``accumulateSelect``).  Methods such as :meth:`orderBy`, :meth:`filter`
    and slicing never mutate ``self``; they return new ``SelectResults``
    built by :meth:`clone` / :meth:`newClause`.
    """

    # NOTE(review): not referenced in this class; presumably consumed by
    # the connection's iterSelect -- verify before changing.
    IterationClass = dbconnection.Iteration

    def __init__(self, sourceClass, clause, clauseTables=None,
                 **ops):
        """Build a result set for ``sourceClass`` filtered by ``clause``.

        :param sourceClass: the SQLObject class whose rows are selected.
        :param clause: an ``sqlbuilder.SQLExpression``, ``None`` or the
            string ``'all'`` (both select every row), or any other value,
            which is wrapped in ``sqlbuilder.SQLConstant`` (e.g. raw SQL).
        :param clauseTables: extra tables to include in the FROM list
            in addition to those found in ``clause``.
        :param ops: query options; keys read by this class are
            ``orderBy``, ``connection``, ``limit``, ``start``, ``end``,
            ``join``, ``distinct``, ``lazyColumns``, ``reversed`` and
            ``forUpdate``.
        """
        self.sourceClass = sourceClass
        if clause is None or (isinstance(clause, string_type)
                              and clause == 'all'):
            clause = sqlbuilder.SQLTrueClause
        if not isinstance(clause, sqlbuilder.SQLExpression):
            # Wrap non-expressions so the rest of the class can always
            # treat self.clause as an SQLExpression.
            clause = sqlbuilder.SQLConstant(clause)
        self.clause = clause
        self.ops = ops
        if ops.get('orderBy', sqlbuilder.NoDefault) is sqlbuilder.NoDefault:
            ops['orderBy'] = sourceClass.sqlmeta.defaultOrder
        orderBy = ops['orderBy']
        if isinstance(orderBy, (tuple, list)):
            orderBy = list(map(self._mungeOrderBy, orderBy))
        else:
            orderBy = self._mungeOrderBy(orderBy)
        # The user-supplied 'orderBy' is kept as-is (clone() re-munges it);
        # the SQL-ready form lives under a separate key.
        ops['dbOrderBy'] = orderBy
        if 'connection' in ops and ops['connection'] is None:
            # connection=None means "fall back to sourceClass._connection"
            # (see _getConnection).
            del ops['connection']
        if ops.get('limit', None):
            # 'limit' is shorthand for the window [0:limit].  A falsy limit
            # (None or 0) is left in ops and ignored by queryForSelect.
            assert not ops.get('start', None) and not ops.get('end', None), \
                "'limit' cannot be used with 'start' or 'end'"
            ops["start"] = 0
            ops["end"] = ops.pop("limit")

        tablesSet = sqlbuilder.tablesUsedSet(self.clause,
                                             self._getConnection().dbName)
        if clauseTables:
            for table in clauseTables:
                tablesSet.add(table)
        self.clauseTables = clauseTables
        # Explicitly post-adding-in sqlmeta.table,
        # sqlbuilder.Select will handle sqlrepr'ing and dupes.
        self.tables = list(tablesSet) + [sourceClass.sqlmeta.table]

    def queryForSelect(self):
        """Return the ``sqlbuilder.Select`` object for these results.

        Selects the ``id`` column followed by every column in
        ``sqlmeta.columnList``, applying the clause and the options in
        ``self.ops``.
        """
        q = self.sourceClass.q
        columns = [q.id] + [getattr(q, x.name)
                            for x in self.sourceClass.sqlmeta.columnList]
        query = sqlbuilder.Select(columns,
                                  where=self.clause,
                                  join=self.ops.get(
                                      'join', sqlbuilder.NoDefault),
                                  distinct=self.ops.get('distinct', False),
                                  lazyColumns=self.ops.get(
                                      'lazyColumns', False),
                                  start=self.ops.get('start', 0),
                                  end=self.ops.get('end', None),
                                  orderBy=self.ops.get(
                                      'dbOrderBy', sqlbuilder.NoDefault),
                                  reversed=self.ops.get('reversed', False),
                                  staticTables=self.tables,
                                  forUpdate=self.ops.get('forUpdate', False))
        return query

    def __repr__(self):
        return "<%s at %x>" % (self.__class__.__name__, id(self))

    def _getConnection(self):
        """Return the connection from ops, else the class default."""
        return self.ops.get('connection') or self.sourceClass._connection

    def __str__(self):
        """Return ``conn.queryForSelect(self)`` -- presumably the SQL text."""
        conn = self._getConnection()
        return conn.queryForSelect(self)

    def _mungeOrderBy(self, orderBy):
        """Translate a single ORDER BY item into an SQL expression.

        A string may start with ``'-'`` to request descending order.  If
        the (stripped) string names a column of ``sourceClass`` the
        matching ``q`` attribute is used, otherwise the string is wrapped
        in ``SQLConstant`` verbatim.  Non-string values are returned
        unchanged.
        """
        if not isinstance(orderBy, string_type):
            return orderBy
        desc = orderBy.startswith('-')
        if desc:
            orderBy = orderBy[1:]
        columns = self.sourceClass.sqlmeta.columns
        if orderBy in columns:
            expr = getattr(self.sourceClass.q, columns[orderBy].name)
        else:
            expr = sqlbuilder.SQLConstant(orderBy)
        return sqlbuilder.DESC(expr) if desc else expr

    def clone(self, **newOps):
        """Return a new instance with ``newOps`` merged over ``self.ops``."""
        ops = self.ops.copy()
        ops.update(newOps)
        return self.__class__(self.sourceClass, self.clause,
                              self.clauseTables, **ops)

    def orderBy(self, orderBy):
        return self.clone(orderBy=orderBy)

    def connection(self, conn):
        return self.clone(connection=conn)

    def limit(self, limit):
        # Equivalent to self[:limit]; note that limit(0) returns self.
        return self[:limit]

    def lazyColumns(self, value):
        return self.clone(lazyColumns=value)

    def reversed(self):
        return self.clone(reversed=not self.ops.get('reversed', False))

    def distinct(self):
        return self.clone(distinct=True)

    def newClause(self, new_clause):
        """Return a new instance with the same ops but ``new_clause``."""
        return self.__class__(self.sourceClass, new_clause,
                              self.clauseTables, **self.ops)

    def filter(self, filter_clause):
        """Return results additionally restricted by ``filter_clause``.

        ``None`` is a no-op and returns ``self``.
        """
        if filter_clause is None:
            # None doesn't filter anything, it's just a no-op:
            return self
        clause = self.clause
        if isinstance(clause, string_type):
            # Defensive: __init__ already wraps non-expressions.
            clause = sqlbuilder.SQLConstant('(%s)' % clause)
        return self.newClause(sqlbuilder.AND(clause, filter_clause))

    def __getitem__(self, value):
        """Slice or index the results.

        Non-negative slices become new ``start``/``end`` options relative
        to the current window, and no query is run.  Negative slice bounds
        and negative indexes fetch the whole result set and slice it in
        Python.  A non-negative integer index fetches just that row and
        raises ``IndexError`` if it lies outside the current window or no
        such row exists.
        """
        if isinstance(value, slice):
            assert not value.step, "Slices do not support steps"
            if not value.start and not value.stop:
                # No need to copy, I'm immutable.
                # NOTE: this also makes [:0] (and so limit(0)) return
                # every row, consistent with select(limit=0).
                return self

            # Negative indexes can't be expressed as start/end, so fetch
            # everything and let Python slice it.  Pass both bounds through
            # so that e.g. [-3:0] is empty, as for a list.
            if ((value.start and value.start < 0)
                    or (value.stop and value.stop < 0)):
                return list(self)[value.start:value.stop]

            baseStart = self.ops.get('start', 0)
            prevEnd = self.ops.get('end', None)
            if value.start:
                assert value.start >= 0
                start = baseStart + value.start
                if value.stop is not None:
                    assert value.stop >= 0
                    if value.stop < value.start:
                        # an empty result:
                        end = start
                    else:
                        end = value.stop + baseStart
                        if prevEnd is not None and prevEnd < end:
                            # truncated by previous slice:
                            end = prevEnd
                else:
                    end = prevEnd
            else:
                # value.start is falsy, so value.stop is a positive int here.
                start = baseStart
                end = value.stop + start
                if prevEnd is not None and prevEnd < end:
                    end = prevEnd
            if end is not None and end < start:
                # The new window starts past the end of the previous one;
                # collapse it to an empty range rather than handing the
                # query builder an end that precedes start.
                end = start
            return self.clone(start=start, end=end)

        if value < 0:
            return list(iter(self))[value]
        start = self.ops.get('start', 0) + value
        end = self.ops.get('end', None)
        if end is not None and start >= end:
            # Past the end of an earlier slice/limit.
            raise IndexError("SelectResults index out of range")
        return list(self.clone(start=start, end=start + 1))[0]

    def __iter__(self):
        # @@: This could be optimized, using a simpler algorithm
        # since we don't have to worry about garbage collection,
        # etc., like we do with .lazyIter()
        return iter(list(self.lazyIter()))

    def lazyIter(self):
        """
        Returns an iterator that will lazily pull rows out of the
        database and return SQLObject instances
        """
        conn = self._getConnection()
        return conn.iterSelect(self)

    def accumulate(self, *expressions):
        """Run an aggregate SELECT over these results.

        Each expression that is not an ``SQLExpression`` is wrapped in
        ``SQLConstant``.  Returns whatever the connection's
        ``accumulateSelect`` returns for them.
        """
        conn = self._getConnection()
        exprs = [expr if isinstance(expr, sqlbuilder.SQLExpression)
                 else sqlbuilder.SQLConstant(expr)
                 for expr in expressions]
        return conn.accumulateSelect(self, *exprs)

    def count(self):
        """Count the rows matched by these (unsliced) results.

        Sliced results (``start``/``end``/``limit``) are rejected by an
        assertion.
        """
        assert not self.ops.get('start') and not self.ops.get('end'), \
            "start/end/limit have no meaning with 'count'"
        assert not (self.ops.get('distinct') and
                    (self.ops.get('start') or self.ops.get('end'))), \
            "distinct-counting of sliced objects is not supported"
        if self.ops.get('distinct'):
            # Column must be specified, so we are using unique ID column.
            # COUNT(DISTINCT column) is supported by MySQL and PostgreSQL,
            # but not by SQLite. Perhaps more portable would be subquery:
            #  SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table)
            count = self.accumulate(
                'COUNT(DISTINCT %s)' % self._getConnection().sqlrepr(
                    self.sourceClass.q.id))
        else:
            count = self.accumulate('COUNT(*)')
        # Only reachable when the asserts above are stripped (python -O).
        if self.ops.get('start'):
            count = max(count - self.ops['start'], 0)
        if self.ops.get('end'):
            count = min(self.ops['end'] - self.ops.get('start', 0), count)
        return count

    def accumulateMany(self, *attributes):
        """ Making the expressions for count/sum/min/max/avg
            of a given select result attributes.
            `attributes` must be a list/tuple of pairs (func_name, attribute);
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        conn = self._getConnection()
        distinct = 'DISTINCT ' if self.ops.get('distinct') else ''
        expressions = []
        for func_name, attribute in attributes:
            # Column names are used verbatim; anything else is rendered
            # through the connection.  string_type keeps Python 2 unicode
            # names from being sqlrepr'd as quoted literals.
            if not isinstance(attribute, string_type):
                attribute = conn.sqlrepr(attribute)
            expressions.append('%s(%s%s)' % (func_name, distinct, attribute))
        return self.accumulate(*expressions)

    def accumulateOne(self, func_name, attribute):
        """ Making the sum/min/max/avg of a given select result attribute.
            `attribute` can be a column name (like 'a_column')
            or a dot-q attribute (like Table.q.aColumn)
        """
        return self.accumulateMany((func_name, attribute))

    def sum(self, attribute):
        return self.accumulateOne("SUM", attribute)

    def min(self, attribute):
        return self.accumulateOne("MIN", attribute)

    def avg(self, attribute):
        return self.accumulateOne("AVG", attribute)

    def max(self, attribute):
        return self.accumulateOne("MAX", attribute)

    def getOne(self, default=sqlbuilder.NoDefault):
        """
        If a query is expected to only return a single value,
        using ``.getOne()`` will return just that value.

        If no results are found, ``SQLObjectNotFound`` will be
        raised, unless you pass in a default value (like
        ``.getOne(None)``).

        If more than one result is returned,
        ``SQLObjectIntegrityError`` will be raised.
        """
        from . import main
        results = list(self)
        if not results:
            if default is sqlbuilder.NoDefault:
                raise main.SQLObjectNotFound(
                    "No results matched the query for %s"
                    % self.sourceClass.__name__)
            return default
        if len(results) > 1:
            raise main.SQLObjectIntegrityError(
                "More than one result returned from query: %s"
                % results)
        return results[0]

    @property
    def throughTo(self):
        """Proxy whose attribute ``.attr`` returns ``self._throughTo(attr)``.

        ``attr`` names a foreign key (with or without the ``ID`` suffix)
        or a join method of ``sourceClass``.
        """
        class _throughTo_getter(object):
            def __init__(self, inst):
                self.sresult = inst

            def __getattr__(self, attr):
                return self.sresult._throughTo(attr)
        return _throughTo_getter(self)

    def _throughTo(self, attr):
        """Select the rows of the class reached through ``attr``.

        Raises ``AttributeError`` when ``attr`` is neither a foreign key
        column nor a join of ``sourceClass``.
        """
        otherClass = None
        orderBy = sqlbuilder.NoDefault

        fkName = attr if attr.endswith('ID') else attr + 'ID'
        ref = self.sourceClass.sqlmeta.columns.get(fkName, None)
        if ref and ref.foreignKey:
            otherClass, clause = self._throughToFK(ref)
        else:
            join = [x for x in self.sourceClass.sqlmeta.joins
                    if x.joinMethodName == attr]
            if join:
                join = join[0]
                orderBy = join.orderBy
                # Joins with an otherColumn go through an intermediate
                # table (handled by _throughToRelatedJoin).
                if hasattr(join, 'otherColumn'):
                    otherClass, clause = self._throughToRelatedJoin(join)
                else:
                    otherClass, clause = self._throughToMultipleJoin(join)

        if not otherClass:
            raise AttributeError(
                "throughTo argument (got %s) should be "
                "name of foreignKey or SQL*Join in %s" % (attr,
                                                          self.sourceClass))

        return otherClass.select(clause,
                                 orderBy=orderBy,
                                 connection=self._getConnection())

    def _throughToFK(self, col):
        """Return ``(otherClass, clause)`` for a foreign-key column.

        The clause matches ``otherClass.q.id`` against a DISTINCT,
        unordered subquery of the FK column over these results.
        """
        otherClass = getattr(self.sourceClass, "_SO_class_" + col.foreignKey)
        colName = col.name
        query = self.queryForSelect().newItems([
            sqlbuilder.ColumnAS(getattr(self.sourceClass.q, colName), colName)
        ]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query,
                                 "%s_%s" % (self.sourceClass.__name__,
                                            col.name))
        return otherClass, otherClass.q.id == getattr(query.q, colName)

    def _throughToMultipleJoin(self, join):
        """Return ``(otherClass, clause)`` for a join without otherColumn.

        The clause matches ``otherClass``'s join column against a DISTINCT
        subquery of the ids of these results.
        """
        otherClass = join.otherClass
        colName = join.soClass.sqlmeta.style. \
            dbColumnToPythonAttr(join.joinColumn)
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]). \
            orderBy(None).distinct()
        query = sqlbuilder.Alias(query,
                                 "%s_%s" % (self.sourceClass.__name__,
                                            join.joinMethodName))
        joinColumn = getattr(otherClass.q, colName)
        return otherClass, joinColumn == query.q.id

    def _throughToRelatedJoin(self, join):
        """Return ``(otherClass, clause)`` for a join via an intermediate table.

        The clause links ``otherClass.q.id`` to ``join.otherColumn`` of
        the intermediate table, and that table's join column to a DISTINCT
        subquery of the ids of these results.
        """
        otherClass = join.otherClass
        intTable = sqlbuilder.Table(join.intermediateTable)
        colName = join.joinColumn
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]). \
            orderBy(None).distinct()
        query = sqlbuilder.Alias(query,
                                 "%s_%s" % (self.sourceClass.__name__,
                                            join.joinMethodName))
        clause = sqlbuilder.AND(
            otherClass.q.id == getattr(intTable, join.otherColumn),
            getattr(intTable, colName) == query.q.id)
        return otherClass, clause
0377        return otherClass, clause