0001from .sqlbuilder import *  # noqa
0002from .main import SQLObject
0003
0004
class ViewSQLObjectField(SQLObjectField):
    """Field of a view class, rendered as ``<alias>.<fieldName>``.

    ``alias`` is the name the field is qualified with in SQL; every other
    positional argument is forwarded unchanged to SQLObjectField.
    """

    def __init__(self, alias, *arg):
        # Let the base class store its own state before we add ours.
        SQLObjectField.__init__(self, *arg)
        self.alias = alias

    def __sqlrepr__(self, db):
        # Qualify with the alias instead of the table; ``db`` is not used.
        return ".".join((self.alias, self.fieldName))

    def tablesUsedImmediate(self):
        # Report the owning table object (as stored by the base class) as
        # the only table this field depends on.
        return [self.tableName]
0015
0016
class ViewSQLObjectTable(SQLObjectTable):
    """Table accessor for view classes.

    Fields it produces are ViewSQLObjectField instances qualified with the
    view class's ``sqlmeta.alias``.
    """

    FieldClass = ViewSQLObjectField

    def __getattr__(self, attr):
        # 'sqlmeta' is never resolved through the base-class lookup.
        # NOTE(review): presumably so it is not mistaken for a column name;
        # confirm against SQLObjectTable.__getattr__.
        if attr != 'sqlmeta':
            return SQLObjectTable.__getattr__(self, attr)
        raise AttributeError

    def _viewField(self, fieldName, attr, column):
        # Shared constructor for view fields:
        # (alias, tableName, fieldName, attr, soClass, column).
        return self.FieldClass(self.soClass.sqlmeta.alias, self.tableName,
                               fieldName, attr, self.soClass, column)

    def _getattrFromID(self, attr):
        return self._viewField('id', attr, None)

    def _getattrFromColumn(self, column, attr):
        return self._viewField(column.name, attr, column)
0032
0033
class ViewSQLObject(SQLObject):
    """
    A SQLObject class that derives all its values from other SQLObject
    classes. Columns on subclasses should use SQLBuilder constructs for dbName,
    and sqlmeta should specify:

    * idName as a SQLBuilder construction
    * clause as SQLBuilder clause for specifying join conditions
      or other restrictions
    * table as an optional alternate name for the class alias

    sqlmeta may also specify ``join``, which is passed through to the
    generated SELECT statements.  Columns whose dbName contains a SQL
    function call (see isAggregate) are computed in GROUP BY subqueries;
    such a column may set ``aggregateClause`` to further restrict the rows
    it aggregates over (in addition to ``sqlmeta.clause``).

    See test_views.py for simple examples.
    """

    def __classinit__(cls, new_attrs):
        """Build the view's SELECT and install it as ``cls.sqlmeta.table``.

        Skipped for ViewSQLObject itself.  Plain columns are selected
        directly.  Aggregate columns are computed in one GROUP BY subquery
        per distinct restriction, and each subquery is LEFT JOINed onto the
        plain SELECT by id.  Afterwards ``cls.q`` is a ViewSQLObjectTable
        and every column's dbName is replaced by its attribute name.
        """
        SQLObject.__classinit__(cls, new_attrs)
        # like is_base
        if cls.__name__ == 'ViewSQLObject':
            return

        # The connection's dbName is only needed to render aggregate
        # restrictions into grouping keys below.
        connection = getattr(cls, '_connection', None)
        dbName = (connection and connection.dbName) or None

        if getattr(cls.sqlmeta, 'table', None):
            cls.sqlmeta.alias = cls.sqlmeta.table
        else:
            cls.sqlmeta.alias = \
                cls.sqlmeta.style.pythonClassToDBTable(cls.__name__)
        alias = cls.sqlmeta.alias
        idName = cls.sqlmeta.idName

        columns = [ColumnAS(idName, 'id')]
        # {sqlrepr-key: [restriction, *aggregate-column]}
        aggregates = {'': [None]}
        inverseColumns = {col: name
                          for name, col in cls.sqlmeta.columns.items()}
        for col in cls.sqlmeta.columnList:
            ascol = ColumnAS(col.dbName, inverseColumns[col])
            if not isAggregate(col.dbName):
                columns.append(ascol)
                continue
            restriction = getattr(col, 'aggregateClause', None)
            if restriction:
                restrictkey = sqlrepr(restriction, dbName)
                aggregates.setdefault(restrictkey, [restriction]).append(ascol)
            else:
                aggregates[''].append(ascol)

        metajoin = getattr(cls.sqlmeta, 'join', NoDefault)
        clause = getattr(cls.sqlmeta, 'clause', NoDefault)
        select = Select(columns,
                        distinct=True,
                        # @@ LDO check if this really mattered
                        # for performance
                        # @@ Postgres (and MySQL?) extension!
                        # distinctOn=cls.sqlmeta.idName,
                        join=metajoin,
                        clause=clause)

        # Keep only groups that actually carry aggregate columns (the
        # unrestricted '' group may have none).  Building a real list also
        # matters on Python 3: a dict view never compares equal to a list,
        # so the old "!= [[None]]" test was always true.
        aggregates = [agg for agg in aggregates.values() if len(agg) > 1]

        if aggregates:
            # Wrap the plain SELECT as "<alias>_base" and re-expose its
            # columns through that alias.
            base_alias = "%s_base" % alias
            base = Alias(select, base_alias)
            columns = [
                ColumnAS(SQLConstant("%s.%s" % (base_alias, x.expr2)),
                         x.expr2) for x in columns]
            join = []

            for i, agg in enumerate(aggregates):
                restriction, aggColumns = agg[0], agg[1:]
                # Combine the view clause with the group's own restriction.
                # With no sqlmeta.clause, AND() would embed the NoDefault
                # sentinel and fail to render, so use the restriction alone.
                if restriction is None:
                    restriction = clause
                elif clause is not NoDefault:
                    restriction = AND(clause, restriction)
                agg_alias = "%s_%s" % (alias, i)
                agg_id = '%s_id' % agg_alias
                subquery = Alias(Select(
                    [ColumnAS(idName, agg_id)] + aggColumns,
                    groupBy=idName,
                    join=metajoin,
                    clause=restriction),
                    agg_alias)
                # Join every subquery against the base id, not against the
                # previous subquery: an id missing from one restricted group
                # is NULL there and would hide it from all later groups.
                # Only the first join names the left-hand table; the others
                # pass None so they continue the same join chain.
                join.append(LEFTJOINOn(
                    base if i == 0 else None, subquery,
                    "%s.id = %s.%s" % (base_alias, agg_alias, agg_id)))
                columns.extend(
                    ColumnAS(SQLConstant("%s.%s" % (agg_alias, c.expr2)),
                             c.expr2)
                    for c in aggColumns)

            select = Select(columns, join=join)

        cls.sqlmeta.table = Alias(select, alias)
        cls.q = ViewSQLObjectTable(cls)
        # The view SELECT exposes each column under its attribute name
        # (via ColumnAS above), so that is the dbName to use from now on.
        for n, col in cls.sqlmeta.columns.items():
            col.dbName = n
0142
0143
def isAggregate(expr):
    """Return True when *expr* is, or contains, a SQL function call.

    A SQLCall counts directly.  A SQLOp is searched recursively through
    both of its operands.  Anything else is not an aggregate.

    NOTE(review): every SQLCall is treated as an aggregate, including
    scalar functions; callers should confirm this heuristic is acceptable.
    """
    if isinstance(expr, SQLCall):
        return True
    if not isinstance(expr, SQLOp):
        return False
    return any(isAggregate(operand) for operand in (expr.expr1, expr.expr2))