0001from array import array
0002import datetime
0003from decimal import Decimal
0004import time
0005import sys
0006from .compat import buffer_type
0007if sys.version_info[0] < 3:
0008    from types import ClassType, InstanceType, NoneType
0009else:
0010    # Use suitable aliases for now
0011    ClassType = type
0012    NoneType = type(None)
0013    # This is may not be what we want in all cases, but will do for now
0014    InstanceType = object
0015
0016
0017try:
0018    from mx.DateTime import DateTimeType, DateTimeDeltaType
0019except ImportError:
0020    try:
0021        from DateTime import DateTimeType, DateTimeDeltaType
0022    except ImportError:
0023        DateTimeType = None
0024        DateTimeDeltaType = None
0025
0026try:
0027    import Sybase
0028    NumericType = Sybase.NumericType
0029except ImportError:
0030    NumericType = None
0031
0032
0033########################################
0034# Quoting
0035########################################
0036
0037sqlStringReplace = [
0038    ("'", "''"),
0039    ('\\', '\\\\'),
0040    ('\000', '\\0'),
0041    ('\b', '\\b'),
0042    ('\n', '\\n'),
0043    ('\r', '\\r'),
0044    ('\t', '\\t'),
0045]
0046
0047
0048class ConverterRegistry:
0049
0050    def __init__(self):
0051        self.basic = {}
0052        self.klass = {}
0053
0054    def registerConverter(self, typ, func):
0055        if type(typ) is ClassType:
0056            self.klass[typ] = func
0057        else:
0058            self.basic[typ] = func
0059
0060    if sys.version_info[0] < 3:
0061        def lookupConverter(self, value, default=None):
0062            if type(value) is InstanceType:
0063                # lookup on klasses dict
0064                return self.klass.get(value.__class__, default)
0065            return self.basic.get(type(value), default)
0066    else:
0067        def lookupConverter(self, value, default=None):
0068            # python 3 doesn't have classic classes, so everything's
0069            # in self.klass due to comparison order in registerConvertor
0070            return self.klass.get(value.__class__, default)
0071
0072converters = ConverterRegistry()
0073registerConverter = converters.registerConverter
0074lookupConverter = converters.lookupConverter
0075
0076
0077def StringLikeConverter(value, db):
0078    if isinstance(value, array):
0079        try:
0080            value = value.tounicode()
0081        except ValueError:
0082            value = value.tostring()
0083    elif isinstance(value, buffer_type):
0084        value = str(value)
0085
0086    if db in ('mysql', 'postgres', 'rdbhost'):
0087        for orig, repl in sqlStringReplace:
0088            value = value.replace(orig, repl)
0089    elif db in ('sqlite', 'firebird', 'sybase', 'maxdb', 'mssql'):
0090        value = value.replace("'", "''")
0091    else:
0092        assert 0, "Database %s unknown" % db
0093    if db in ('postgres', 'rdbhost') and ('\\' in value):
0094        return "E'%s'" % value
0095    return "'%s'" % value
0096
0097registerConverter(str, StringLikeConverter)
0098if sys.version_info[0] < 3:
0099    # noqa for flake8 & python3
0100    registerConverter(unicode, StringLikeConverter)  # noqa
0101registerConverter(array, StringLikeConverter)
0102if sys.version_info[0] < 3:
0103    registerConverter(buffer_type, StringLikeConverter)
0104else:
0105    registerConverter(memoryview, StringLikeConverter)
0106
0107
0108def IntConverter(value, db):
0109    return repr(int(value))
0110
0111registerConverter(int, IntConverter)
0112
0113
0114def LongConverter(value, db):
0115    return str(value)
0116
0117if sys.version_info[0] < 3:
0118    # noqa for flake8 & python3
0119    registerConverter(long, LongConverter)  # noqa
0120
0121if NumericType:
0122    registerConverter(NumericType, IntConverter)
0123
0124
0125def BoolConverter(value, db):
0126    if db in ('postgres', 'rdbhost'):
0127        if value:
0128            return "'t'"
0129        else:
0130            return "'f'"
0131    else:
0132        if value:
0133            return '1'
0134        else:
0135            return '0'
0136
0137registerConverter(bool, BoolConverter)
0138
0139
0140def FloatConverter(value, db):
0141    return repr(value)
0142
0143registerConverter(float, FloatConverter)
0144
0145if DateTimeType:
0146    def DateTimeConverter(value, db):
0147        return "'%s'" % value.strftime("%Y-%m-%d %H:%M:%S.%s")
0148
0149    registerConverter(DateTimeType, DateTimeConverter)
0150
0151    def TimeConverter(value, db):
0152        return "'%s'" % value.strftime("%H:%M:%S")
0153
0154    registerConverter(DateTimeDeltaType, TimeConverter)
0155
0156
0157def NoneConverter(value, db):
0158    return "NULL"
0159
0160registerConverter(NoneType, NoneConverter)
0161
0162
0163def SequenceConverter(value, db):
0164    return "(%s)" % ", ".join([sqlrepr(v, db) for v in value])
0165
0166registerConverter(tuple, SequenceConverter)
0167registerConverter(list, SequenceConverter)
0168registerConverter(dict, SequenceConverter)
0169registerConverter(set, SequenceConverter)
0170registerConverter(frozenset, SequenceConverter)
0171
0172if hasattr(time, 'struct_time'):
0173    def StructTimeConverter(value, db):
0174        return time.strftime("'%Y-%m-%d %H:%M:%S'", value)
0175
0176    registerConverter(time.struct_time, StructTimeConverter)
0177
0178
0179def DateTimeConverter(value, db):
0180    return "'%04d-%02d-%02d %02d:%02d:%02d.%06d'" % (
0181        value.year, value.month, value.day,
0182        value.hour, value.minute, value.second, value.microsecond)
0183
0184registerConverter(datetime.datetime, DateTimeConverter)
0185
0186
0187def DateConverter(value, db):
0188    return "'%04d-%02d-%02d'" % (value.year, value.month, value.day)
0189
0190registerConverter(datetime.date, DateConverter)
0191
0192
0193def TimeConverter(value, db):
0194    return "'%02d:%02d:%02d.%06d'" % (value.hour, value.minute,
0195                                      value.second, value.microsecond)
0196
0197registerConverter(datetime.time, TimeConverter)
0198
0199
0200def DecimalConverter(value, db):
0201    return value.to_eng_string()
0202
0203registerConverter(Decimal, DecimalConverter)
0204
0205
0206def TimedeltaConverter(value, db):
0207
0208    return """INTERVAL '%d days %d seconds'""" %           (value.days, value.seconds)
0210
0211registerConverter(datetime.timedelta, TimedeltaConverter)
0212
0213
0214def sqlrepr(obj, db=None):
0215    try:
0216        reprFunc = obj.__sqlrepr__
0217    except AttributeError:
0218        converter = lookupConverter(obj)
0219        if converter is None:
0220            raise ValueError("Unknown SQL builtin type: %s for %s" %
0221                             (type(obj), repr(obj)))
0222        return converter(obj, db)
0223    else:
0224        return reprFunc(db)
0225
0226
0227def quote_str(s, db):
0228    if db in ('postgres', 'rdbhost') and ('\\' in s):
0229        return "E'%s'" % s
0230    return "'%s'" % s
0231
0232
0233def unquote_str(s):
0234    if s[:2].upper().startswith("E'") and s.endswith("'"):
0235        return s[2:-1]
0236    elif s.startswith("'") and s.endswith("'"):
0237        return s[1:-1]
0238    else:
0239        return s