0001from sqlapi import sql
0002
0003__all__ = ['SelectResults']
0004
class SelectResults(object):
    """A SELECT query whose rows are turned into values by builder callables.

    ``column_specs`` is a sequence of ``(columns, builder)`` pairs.  The
    columns of all specs are concatenated (in spec order) into one
    ``sql.Select``; each fetched row is then split back into consecutive
    slices, and each slice is passed to the matching ``builder``.

    The query runs only when the object is iterated, indexed, or counted,
    and it runs again on every such access.
    """

    def __init__(self, column_specs, clause, return_single=True,
                 connection=None):
        """Build the underlying ``sql.Select`` without executing it.

        :param column_specs: sequence of ``(columns, builder)`` pairs.
        :param clause: passed as the second argument to ``sql.Select`` (and
            reused by :meth:`count`).
        :param return_single: if true, each result is the single built value
            instead of a list of built values; requires exactly one spec.
        :param connection: object providing ``cursor()``; its cursors must
            support ``execute``, ``fetchall``/``fetchone`` and ``close``.
        :raises AssertionError: if ``return_single`` is true and
            ``column_specs`` does not contain exactly one spec.
        """
        self.column_specs = column_specs
        self.clause = clause
        self.connection = connection
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; callers still see AssertionError.  The one-element
        # tuple keeps ``%`` from unpacking a tuple of specs (which would
        # raise TypeError instead).
        if return_single and len(column_specs) != 1:
            raise AssertionError(
                "You can only use return_single when a single spec "
                "is passed in (got %r)" % (column_specs,))
        self.return_single = return_single
        # Flatten every spec's columns, in order, into one select list.
        self.columns = [col for cols, _builder in self.column_specs
                        for col in cols]
        self.sql_select = sql.Select(self.columns, clause)

    def __iter__(self):
        """Run the query and yield one built result per fetched row."""
        for row in self.do_query():
            yield self.produce_from_row(row)

    def do_query(self):
        """Execute ``self.sql_select`` and return all fetched rows.

        The cursor is closed even if ``execute`` or ``fetchall`` raises.
        """
        cur = self.connection.cursor()
        try:
            cur.execute(self.sql_select)
            return cur.fetchall()
        finally:
            cur.close()

    def produce_from_row(self, row):
        """Convert one fetched row into built value(s).

        :returns: the single built value when ``return_single`` is true,
            otherwise a list with one built value per spec.
        """
        if self.return_single:
            value, _rest = self.produce_spec_from_row(
                self.column_specs[0], row)
            return value
        results = []
        rest = row
        for spec in self.column_specs:
            # Each spec consumes its columns off the front of the row.
            result, rest = self.produce_spec_from_row(spec, rest)
            results.append(result)
        return results

    def produce_spec_from_row(self, spec, row):
        """Build one value from the leading columns of ``row``.

        :param spec: a ``(columns, builder)`` pair.
        :param row: the remaining (not yet consumed) row values.
        :returns: ``(value, remaining_row)``, where ``value`` is
            ``builder(row[:len(columns)])``.
        :raises AssertionError: if ``row`` has fewer values than the spec
            needs.
        """
        columns, builder = spec
        width = len(columns)
        # Fail only when the row is too *short*; extra values are expected
        # because they belong to the specs that follow.
        if len(row) < width:
            raise AssertionError(
                "Need %i columns to fill %r; only %i left"
                % (width, builder, len(row)))
        return builder(row[:width]), row[width:]

    def __getitem__(self, item):
        """Index or slice the results.

        This re-runs the query and materializes every result first.
        """
        return list(self)[item]

    def count(self):
        """Return the COUNT of rows matching ``self.clause``.

        The cursor is closed even if the query fails.
        """
        COUNT = sql.funcs.COUNT
        # NOTE(review): star_from receives one column list per spec; its
        # exact semantics live in sqlapi and are not visible here.
        star = sql.star_from([cols for cols, _builder in self.column_specs])
        query = sql.Select([COUNT(star)], self.clause)
        cur = self.connection.cursor()
        try:
            cur.execute(query)
            (count,) = cur.fetchone()
        finally:
            cur.close()
        return count