0001__all__ = ['sqlexpr', 'SQLExpression', 'sqlrepr', 'Table',
0002           'Alias', 'sqlparams', 'Column', 'funcs',
0003           'Select', 'Update', 'Delete', 'Insert']
0004
0005class sqlexpr(str):
0006    def __repr__(self):
0007        return '<SQL %s>' % self
0008    def __sql__(self, dbsql):
0009        return self
0010
0011def _comma_join(seq, join=', '):
0012    """
0013    Returns the sequence with ``sqlexpr`` commas between items
0014    """
0015    result = []
0016    first = True
0017    for item in seq:
0018        if not first:
0019            result.append(sqlexpr(join))
0020        else:
0021            first = False
0022        result.append(item)
0023    return result
0024
0025class SQLPiece(object):
0026
0027    def __add__(self, other):
0028        return SQLOp(self, '+', other)
0029    def __radd__(self, other):
0030        return SQLOp(other, '+', self)
0031    def __sub__(self, other):
0032        return SQLOp(self, '-', other)
0033    def __rsub__(self, other):
0034        return SQLOp(other, '-', self)
0035    def __mul__(self, other):
0036        return SQLOp(self, '*', other)
0037    def __rmul__(self, other):
0038        return SQLOp(other, '*', self)
0039    def __div__(self, other):
0040        return SQLOp(self, '/', other)
0041    def __rdiv__(self, other):
0042        return SQLOp(other, '/', self)
0043    def __pos__(self):
0044        return SQLUnaryOp('+', self)
0045    def __neg__(self):
0046        return SQLUnaryOp('-', self)
0047    def __pow__(self, other):
0048        return self._expr_class(sqlexpr('POW'))(self, other)
0049    def __rpow__(self, other):
0050        return self._expr_class(sqlexpr("POW"))(other, self)
0051    def __abs__(self):
0052        return self._expr_class(sqlexpr("ABS"))(self)
0053    def __mod__(self, other):
0054        return self._expr_class(sqlexpr("MOD"))(self, other)
0055    def __rmod__(self, other):
0056        return self._expr_class(sqlexpr("MOD"))(other, self)
0057
0058    def __lt__(self, other):
0059        return SQLOp(self, '<', other)
0060    def __le__(self, other):
0061        return SQLOp(self, '<=', other)
0062    def __gt__(self, other):
0063        return SQLOp(self, '>', other)
0064    def __ge__(self, other):
0065        return SQLOp(self, '>=', other)
0066    def __eq__(self, other):
0067        if other is None:
0068            return SQLPostfixOp(self, 'IS NULL')
0069        else:
0070            return SQLOp(self, '=', other)
0071    def __ne__(self, other):
0072        if other is None:
0073            return SQLPostfixOp(self, 'IS NOT NULL')
0074        else:
0075            return SQLOp(self, '<>', other)
0076
0077    def __and__(self, other):
0078        return SQLOp(self, 'AND', other)
0079    def __rand__(self, other):
0080        return SQLOp(other, 'AND', self)
0081    def __or__(self, other):
0082        return SQLOp(self, 'OR', other)
0083    def __ror__(self, other):
0084        return SQLOp(other, 'OR', self)
0085    def __invert__(self):
0086        return SQLUnaryOp('NOT', self)
0087
0088    def __call__(self, *args):
0089        all = [self, sqlexpr('(')]
0090        for arg in args[:-1]:
0091            all.append(arg)
0092            all.append(sqlexpr(', '))
0093        all.append(args[-1])
0094        all.append(sqlexpr(')'))
0095        return self._expr_class(*all)
0096
0097    def __repr__(self):
0098        return '<%s %s>' % (
0099            self.__class__.__name__, sqlrepr(self))
0100
0101    def __cmp__(self, other):
0102        raise VersionError, "Python 2.1+ required"
0103    def __rcmp__(self, other):
0104        raise VersionError, "Python 2.1+ required"
0105
0106    def startswith(self, s):
0107        return startswith(self, s)
0108    def endswith(self, s):
0109        return endswith(self, s)
0110    def contains(self, s):
0111        return contains_string(self, s)
0112
0113class SQLExpression(SQLPiece):
0114    def __init__(self, *components, **kw):
0115        self.components = components
0116        if 'expression' in kw:
0117            self.components = (
0118                (sqlexpr('('),)
0119                + list(self.components)
0120                + (sqlexpr(')'),))
0121        if kw:
0122            raise TypeError("Keyword arguments %s not allowed"
0123                            % kw)
0124
0125    def __sql_components__(self, dbsql):
0126        return self.components
0127
0128SQLPiece._expr_class = SQLExpression
0129
0130class SQLOp(SQLExpression):
0131
0132    def __init__(self, expr1, operator, expr2):
0133        self.expr1 = expr1
0134        self.expr2 = expr2
0135        self.operator = operator
0136
0137    def __sql_components__(self, dbsql):
0138        return [sqlexpr('('),
0139                self.expr1,
0140                sqlexpr(' %s ' % dbsql.sql_operator(self.operator)),
0141                self.expr2,
0142                sqlexpr(')')]
0143
0144class SQLUnaryOp(SQLExpression):
0145
0146    def __init__(self, operator, expr):
0147        self.operator = operator
0148        self.expr = expr
0149
0150    def __sql_components__(self, dbsql):
0151        return [sqlexpr(dbsql.sql_unary_operator(self.operator)+' '),
0152                self.expr]
0153
0154class SQLPostfixOp(SQLExpression):
0155
0156    def __init__(self, expr, operator):
0157        self.expr = expr
0158        self.operator = operator
0159
0160    def __sql_components__(self, dbsql):
0161        return [self.expr,
0162                sqlexpr(' '+dbsql.sql_postfix_operator(self.operator))]
0163
0164class Table(SQLPiece):
0165
0166    def __init__(self, table):
0167        self.table_name = table
0168
0169    def __sql_from__(self, dbsql):
0170        return {self.table_name: self}
0171
0172    def __sql_components__(self, dbsql):
0173        return [sqlexpr(dbsql.sql_table(self.table_name))]
0174
0175    def __repr__(self):
0176        return '<%s %s>' % (self.__class__.__name__, self.table_name)
0177
0178    def __getattr__(self, column):
0179        if column.startswith('__'):
0180            raise AttributeError
0181        return self[column]
0182
0183    def __getitem__(self, column):
0184        if isinstance(column, (int, long)):
0185            raise KeyError(
0186                "You cannot iterate over tables (with %r)" % column)
0187        return Column(self, column)
0188
0189    #@classmethod
0190    def maybe_name(cls, value):
0191        if isinstance(value, cls):
0192            return value
0193        else:
0194            return cls(value)
0195
0196    maybe_name = classmethod(maybe_name)
0197
0198class Alias(Table):
0199
0200    def __init__(self, table, alias):
0201        self.orig_table_name = table
0202        self.alias = alias
0203        Table.__init__(self, alias)
0204
0205    def __sql_from__(self, dbsql):
0206        return {self.alias:
0207                SQLExpression(Table(self.orig_table_name),
0208                              sqlexpr(' AS '),
0209                              Table(self.table_name))}
0210
0211    def __repr__(self):
0212        return '<%s %s AS %s>' % (
0213            self.__class__.__name__, self.orig_table_name, self.table_name)
0214
0215class Column(SQLPiece):
0216
0217    def __init__(self, table, column):
0218        self.column_name = column
0219        self.table = table
0220
0221    def __sql_from__(self, dbsql):
0222        return self.table.__sql_from__(dbsql)
0223
0224    def __sql_components__(self, dbsql):
0225        parts = self.table.__sql_components__(dbsql)
0226        parts.extend(
0227            [sqlexpr('.'), sqlexpr(dbsql.sql_column(self.column_name))])
0228        return parts
0229
0230def AND(expr, *exprs):
0231    if not exprs:
0232        return expr
0233    return AND(expr & exprs[0], *exprs[1:])
0234
0235def OR(expr, *exprs):
0236    if not exprs:
0237        return expr
0238    return OR(expr | exprs[0], *exprs[1:])
0239
0240def NOT(expr):
0241    return ~expr
0242
0243def LIKE(expr, pat):
0244    return SQLOp(expr, 'LIKE', pat, expression=True)
0245
0246def ILIKE(expr, pat):
0247    return SQLOp(expr, 'ILIKE', pat, expression=True)
0248
0249def endswith(expr, string):
0250    if isinstance(string, basestring):
0251        return LIKE(expr, '%' + _like_quote(string))
0252    else:
0253        return LIKE(expr, strconcat('%', string))
0254
0255def startswith(expr, string):
0256    if isinstance(string, basestring):
0257        return LIKE(expr, _like_quote(string) + '%')
0258    else:
0259        return LIKE(expr, strconcat(string, '%'))
0260
0261def strcontains(expr, string):
0262    if isinstance(string, basestring):
0263        return LIKE(expr, '%' + _like_quote(string) + '%')
0264    else:
0265        return LIKE(expr, strconcat('%', string, '%'))
0266
0267def _like_quote(string):
0268    return string.replace('%', '%%')
0269
0270def strconcat(string, *args):
0271    if not args:
0272        return string
0273    parts = [string]
0274    for part in args:
0275        parts.append(literal(' || '))
0276        parts.append(part)
0277    return SQLExpression(*parts)
0278
0279class star_from(object):
0280    """
0281    For things like ``COUNT(*)``, using
0282    ``COUNT(star_from(FROM_TABLES))``
0283    """
0284
0285    def __init__(self, *parts):
0286        self.parts = parts
0287
0288    def __sql_components__(self, dbsql):
0289        return [sqlexpr('*')]
0290
0291    def __sql_from__(self, dbsql):
0292        result = {}
0293        _collect_tables_to(self.parts, result, dbsql)
0294        return result
0295
0296class _Funcs(object):
0297
0298    def __getattr__(self, func_name):
0299        if func_name.startswith('__'):
0300            raise AttributeError
0301        return Func(func_name)
0302
0303funcs = _Funcs()
0304
0305class Func(SQLPiece):
0306
0307    def __init__(self, func_name):
0308        self.func_name = func_name
0309
0310    def __sql_components__(self, dbsql):
0311        return [sqlexpr(dbsql.sql_function(self.func_name))]
0312
0313############################################################
0314## Statements
0315############################################################
0316
0317class Insert(object):
0318
0319    def __init__(self, table, values):
0320        self.table = table
0321        self.values = values
0322
0323    def __sql_components__(self, dbsql):
0324        items = self.values.items()
0325        result = [sqlexpr('INSERT INTO '),
0326                  sqlexpr(dbsql.sql_table(self.table)),
0327                  sqlexpr(' ('),
0328                  ]
0329        result.extend(_comma_join([
0330            sqlexpr(name) for name, value in items]))
0331        result.append(sqlexpr(') VALUES ('))
0332        result.extend(_comma_join([
0333            value for name, value in items]))
0334        result.append(sqlexpr(')'))
0335        return result
0336
0337class Delete(object):
0338
0339    def __init__(self, table, clause):
0340        self.table = Table.maybe_name(table)
0341        self.clause = clause
0342
0343    def __sql_components__(self, dbsql):
0344        result = [sqlexpr('DELETE FROM '),
0345                  self.table]
0346        if self.clause is not True:
0347            result.extend([
0348                sqlexpr(' WHERE '),
0349                self.clause])
0350        return result
0351
0352class Update(object):
0353
0354    def __init__(self, table, values, clause):
0355        self.table = Table.maybe_name(table)
0356        self.values = values
0357        self.clause = clause
0358
0359    def __sql_components__(self, dbsql):
0360        result = [sqlexpr('UPDATE '),
0361                  self.table,
0362                  sqlexpr(' SET '),
0363                  ]
0364        parts = []
0365        for name, value in self.values:
0366            parts.append([
0367                sqlexpr(name),
0368                sqlexpr(' = '),
0369                value])
0370        result.extend(_comma_join(parts, ',\n  '))
0371        if self.clause is not True:
0372            result.extend([sqlexpr(' WHERE '), self.clause])
0373        return result
0374
0375class Select(object):
0376
0377    def __init__(self, columns, clause, tables=None,
0378                 order_by=None, distinct=False):
0379        self.columns = columns
0380        self.clause = clause
0381        if tables is not None:
0382            self.tables = [Table.maybe_name(t) for t in tables]
0383        else:
0384            self.tables = None
0385        self.order_by = order_by
0386        self.distinct = distinct
0387
0388    def __sql_components__(self, dbsql):
0389        if self.tables is None:
0390            tables = _collect_tables([self.clause, self.columns], dbsql)
0391        else:
0392            tables = self.tables
0393        result = [sqlexpr('SELECT ')]
0394        if self.distinct:
0395            result.append(sqlexpr('DISTINCT '))
0396        result.extend(_comma_join(self.columns))
0397        result.append(sqlexpr(' FROM '))
0398        result.extend(_comma_join(tables))
0399        if self.clause is not True:
0400            result.append(sqlexpr(' WHERE '))
0401            result.append(self.clause)
0402        if self.order_by is not None:
0403            result.append(sqlexpr(' ORDER BY '))
0404            result.extend(_comma_join(self.order_by))
0405        return result
0406
0407def _collect_tables(queries, dbsql):
0408    result = {}
0409    _collect_tables_to(queries, result, dbsql)
0410    return result.values()
0411
0412def _collect_tables_to(queries, result, dbsql):
0413    if isinstance(queries, (list, tuple)):
0414        for item in queries:
0415            _collect_tables_to(item, result, dbsql)
0416        return
0417    if hasattr(queries, '__sql_from__'):
0418        sql_from = queries.__sql_from__(dbsql)
0419        result.update(sql_from)
0420    elif hasattr(queries, '__sql_components__'):
0421        for item in queries.__sql_components__(dbsql):
0422            _collect_tables_to(item, result, dbsql)
0423
0424class Desc(object):
0425
0426    oper = 'DESC'
0427
0428    def __init__(self, expr):
0429        self.expr = expr
0430
0431    def __sql_components__(self, dbsql):
0432        result = list(self.expr.__sql_components__(dbsql))
0433        result.append(sqlexpr(' '+self.oper))
0434        return result
0435
0436############################################################
0437## DDL
0438############################################################
0439
0440class CreateTable(object):
0441
0442    def __init__(self, table_name, columns):
0443        # @@: Should this take **kw?
0444        self.table_name = table_name
0445        self.columns = columns
0446
0447    def __sql_components__(self, dbsql):
0448        parts = [
0449            sqlexpr('CREATE TABLE '),
0450            sqlexpr(dbsql.sql_table(self.table_name)),
0451            sqlexpr(' (\n  '),
0452            ]
0453        parts.extend(_comma_join(
0454            [column.__sql_components__(dbsql)
0455             for column in self.columns],
0456            ',\n  '))
0457        parts.append(sqlexpr('\n)'))
0458        return parts
0459
0460    def __repr__(self):
0461        return '<%s for %s with %s>' % (
0462            self.__class__.__name__, self.table_name,
0463            ', '.join(map(repr, self.columns)))
0464
0465class DropTable(object):
0466
0467    def __init__(self, table_name):
0468        # @@: Cascade
0469        self.table_name = table_name
0470
0471    def __sql_components__(self, dbsql):
0472        return [
0473            sqlexpr('DROP TABLE '),
0474            sqlexpr(dbsql.sql_table(self.table_name)),
0475            ]
0476
0477    def __repr__(self):
0478        return '<%s for %s>' % (
0479            self.__class__.__name__, self.table_name)
0480
0481class ColumnDefinition(object):
0482
0483    def __init__(self, name, sqltype, not_null=False,
0484                 auto_increment=False, primary_key=False,
0485                 **args):
0486        self.name = name
0487        self.sqltype = sqltype
0488        self.not_null = not_null
0489        self.auto_increment = auto_increment
0490        self.primary_key = primary_key
0491        self.args = args
0492
0493    def __sql_components__(self, dbsql):
0494        s = [
0495            sqlexpr(dbsql.sql_column(self.name)),
0496            sqlexpr(' '),
0497            sqlexpr(dbsql.sql_column_create(self.name, self.sqltype, self.args, self.sqltype)),
0498            ]
0499        if self.auto_increment and dbsql.auto_increment:
0500            text = dbsql.sql_auto_increment_def()
0501            if text is not None:
0502                s.append(sqlexpr(' '+text))
0503        if self.primary_key:
0504            s.append(sqlexpr(' PRIMARY KEY'))
0505        if self.not_null:
0506            s.append(sqlexpr(' NOT NULL'))
0507        return s
0508
0509    def __repr__(self):
0510        return '<%s for %s>' % (self.__class__.__name__, sqlrepr(self))
0511
0512############################################################
0513## Rendering expressions
0514############################################################
0515
0516def default_dbsql():
0517    from sqlapi.backend.generic import DatabasePlugin
0518    return DatabasePlugin()
0519
0520def traverse_expr(expr, dbsql, literal_callback):
0521    """
0522    Yield all the bits of the given expression
0523    """
0524    if isinstance(expr, (list, tuple)):
0525        for item in expr:
0526            for subitem in traverse_expr(item, dbsql, literal_callback):
0527                yield subitem
0528    elif hasattr(expr, '__sql__'):
0529        yield expr.__sql__(dbsql)
0530    elif hasattr(expr, '__sql_components__'):
0531        for item in expr.__sql_components__(dbsql):
0532            for subitem in traverse_expr(item, dbsql, literal_callback):
0533                yield subitem
0534    else:
0535        yield literal_callback(expr)
0536
0537def sqlrepr(expr, dbsql=None):
0538    """
0539    Returns a string representation 
0540    """
0541    dbsql = dbsql or default_dbsql()
0542    exprs = traverse_expr(expr, dbsql, dbsql.sql_literal)
0543    s = ''.join(exprs)
0544    return s
0545
0546def sqlparams(expr, dbsql=None):
0547    dbsql = dbsql or default_dbsql()
0548    markergen = iter(dbsql.param_mark_generator())
0549    s = []
0550    all_params = None
0551    def make_marker(s):
0552        return (s,)
0553    for item in traverse_expr(expr, dbsql, make_marker):
0554        if isinstance(item, tuple):
0555            item_value = dbsql.sql_literal_param(item[0])
0556            marker, name = markergen.next()
0557            if name is None:
0558                if all_params is None:
0559                    all_params = [item_value]
0560                else:
0561                    all_params.append(item_value)
0562            else:
0563                if all_params is None:
0564                    all_params = {name: item_value}
0565                else:
0566                    all_params[name] = item_value
0567            s.append(marker)
0568        else:
0569            s.append(item)
0570    if isinstance(all_params, list):
0571        # @@: Is this needed?
0572        all_params = tuple(all_params)
0573    return (''.join(s), all_params)