0001from array import array
0002import datetime
0003from decimal import Decimal
0004import time
0005import sys
0006from .compat import PY2, buffer_type
0007
0008if PY2:
0009    from types import ClassType, InstanceType, NoneType
0010else:
0011    # Use suitable aliases for now
0012    ClassType = type
0013    NoneType = type(None)
0014    # This is may not be what we want in all cases, but will do for now
0015    InstanceType = object
0016
0017
0018try:
0019    from mx.DateTime import DateTimeType, DateTimeDeltaType
0020except ImportError:
0021    try:
0022        from DateTime import DateTimeType, DateTimeDeltaType
0023    except ImportError:
0024        DateTimeType = None
0025        DateTimeDeltaType = None
0026
0027try:
0028    import Sybase
0029    NumericType = Sybase.NumericType
0030except ImportError:
0031    NumericType = None
0032
0033
0034########################################
0035# Quoting
0036########################################
0037
0038sqlStringReplace = [
0039    ("'", "''"),
0040    ('\\', '\\\\'),
0041    ('\000', '\\0'),
0042    ('\b', '\\b'),
0043    ('\n', '\\n'),
0044    ('\r', '\\r'),
0045    ('\t', '\\t'),
0046]
0047
0048
0049class ConverterRegistry:
0050
0051    def __init__(self):
0052        self.basic = {}
0053        self.klass = {}
0054
0055    def registerConverter(self, typ, func):
0056        if type(typ) is ClassType:
0057            self.klass[typ] = func
0058        else:
0059            self.basic[typ] = func
0060
0061    if PY2:
0062        def lookupConverter(self, value, default=None):
0063            if type(value) is InstanceType:
0064                # lookup on klasses dict
0065                return self.klass.get(value.__class__, default)
0066            return self.basic.get(type(value), default)
0067    else:
0068        def lookupConverter(self, value, default=None):
0069            # python 3 doesn't have classic classes, so everything's
0070            # in self.klass due to comparison order in registerConvertor
0071            return self.klass.get(value.__class__, default)
0072
0073converters = ConverterRegistry()
0074registerConverter = converters.registerConverter
0075lookupConverter = converters.lookupConverter
0076
0077
0078def StringLikeConverter(value, db):
0079    if isinstance(value, array):
0080        try:
0081            value = value.tounicode()
0082        except ValueError:
0083            value = value.tostring()
0084    elif isinstance(value, buffer_type):
0085        value = str(value)
0086
0087    if db in ('mysql', 'postgres', 'rdbhost'):
0088        for orig, repl in sqlStringReplace:
0089            value = value.replace(orig, repl)
0090    elif db in ('sqlite', 'firebird', 'sybase', 'maxdb', 'mssql'):
0091        value = value.replace("'", "''")
0092    else:
0093        assert 0, "Database %s unknown" % db
0094    if db in ('postgres', 'rdbhost') and ('\\' in value):
0095        return "E'%s'" % value
0096    return "'%s'" % value
0097
0098registerConverter(str, StringLikeConverter)
0099if PY2:
0100    # noqa for flake8 & python3
0101    registerConverter(unicode, StringLikeConverter)  # noqa
0102registerConverter(array, StringLikeConverter)
0103if PY2:
0104    registerConverter(buffer_type, StringLikeConverter)
0105else:
0106    registerConverter(memoryview, StringLikeConverter)
0107
0108
0109def IntConverter(value, db):
0110    return repr(int(value))
0111
0112registerConverter(int, IntConverter)
0113
0114
0115def LongConverter(value, db):
0116    return str(value)
0117
0118if sys.version_info[0] < 3:
0119    # noqa for flake8 & python3
0120    registerConverter(long, LongConverter)  # noqa
0121
0122if NumericType:
0123    registerConverter(NumericType, IntConverter)
0124
0125
0126def BoolConverter(value, db):
0127    if db in ('postgres', 'rdbhost'):
0128        if value:
0129            return "'t'"
0130        else:
0131            return "'f'"
0132    else:
0133        if value:
0134            return '1'
0135        else:
0136            return '0'
0137
0138registerConverter(bool, BoolConverter)
0139
0140
0141def FloatConverter(value, db):
0142    return repr(value)
0143
0144registerConverter(float, FloatConverter)
0145
0146if DateTimeType:
0147    def DateTimeConverter(value, db):
0148        return "'%s'" % value.strftime("%Y-%m-%d %H:%M:%S.%s")
0149
0150    registerConverter(DateTimeType, DateTimeConverter)
0151
0152    def TimeConverter(value, db):
0153        return "'%s'" % value.strftime("%H:%M:%S")
0154
0155    registerConverter(DateTimeDeltaType, TimeConverter)
0156
0157
0158def NoneConverter(value, db):
0159    return "NULL"
0160
0161registerConverter(NoneType, NoneConverter)
0162
0163
0164def SequenceConverter(value, db):
0165    return "(%s)" % ", ".join([sqlrepr(v, db) for v in value])
0166
0167registerConverter(tuple, SequenceConverter)
0168registerConverter(list, SequenceConverter)
0169registerConverter(dict, SequenceConverter)
0170registerConverter(set, SequenceConverter)
0171registerConverter(frozenset, SequenceConverter)
0172
0173if hasattr(time, 'struct_time'):
0174    def StructTimeConverter(value, db):
0175        return time.strftime("'%Y-%m-%d %H:%M:%S'", value)
0176
0177    registerConverter(time.struct_time, StructTimeConverter)
0178
0179
0180def DateTimeConverter(value, db):
0181    return "'%04d-%02d-%02d %02d:%02d:%02d.%06d'" % (
0182        value.year, value.month, value.day,
0183        value.hour, value.minute, value.second, value.microsecond)
0184
0185registerConverter(datetime.datetime, DateTimeConverter)
0186
0187
0188def DateConverter(value, db):
0189    return "'%04d-%02d-%02d'" % (value.year, value.month, value.day)
0190
0191registerConverter(datetime.date, DateConverter)
0192
0193
0194def TimeConverter(value, db):
0195    return "'%02d:%02d:%02d.%06d'" % (value.hour, value.minute,
0196                                      value.second, value.microsecond)
0197
0198registerConverter(datetime.time, TimeConverter)
0199
0200
0201def DecimalConverter(value, db):
0202    return value.to_eng_string()
0203
0204registerConverter(Decimal, DecimalConverter)
0205
0206
0207def TimedeltaConverter(value, db):
0208
0209    return """INTERVAL '%d days %d seconds'""" %           (value.days, value.seconds)
0211
0212registerConverter(datetime.timedelta, TimedeltaConverter)
0213
0214
0215def sqlrepr(obj, db=None):
0216    try:
0217        reprFunc = obj.__sqlrepr__
0218    except AttributeError:
0219        converter = lookupConverter(obj)
0220        if converter is None:
0221            raise ValueError("Unknown SQL builtin type: %s for %s" %
0222                             (type(obj), repr(obj)))
0223        return converter(obj, db)
0224    else:
0225        return reprFunc(db)
0226
0227
0228def quote_str(s, db):
0229    if db in ('postgres', 'rdbhost') and ('\\' in s):
0230        return "E'%s'" % s
0231    return "'%s'" % s
0232
0233
0234def unquote_str(s):
0235    if s[:2].upper().startswith("E'") and s.endswith("'"):
0236        return s[2:-1]
0237    elif s.startswith("'") and s.endswith("'"):
0238        return s[1:-1]
0239    else:
0240        return s