0001"""
0002Col -- SQLObject columns
0003
0004Note that each column object is named BlahBlahCol, and these are used
0005in class definitions.  But there's also a corresponding SOBlahBlahCol
0006object, which is used in SQLObject *classes*.
0007
0008An explanation: when a SQLObject subclass is created, the metaclass
0009looks through your class definition for any subclasses of Col.  It
0010collects them together, and indexes them to do all the database stuff
0011you like, like the magic attributes and whatnot.  It then asks the Col
0012object to create an SOCol object (usually a subclass, actually).  The
0013SOCol object contains all the interesting logic, as well as a record
0014of the attribute name you used and the class it is bound to (set by
0015the metaclass).
0016
0017So, in summary: Col objects are what you define, but SOCol objects
0018are what gets used.
0019"""
0020
0021from array import array
0022from decimal import Decimal
0023from itertools import count
0024import time
0025from uuid import UUID
0026try:
0027    import cPickle as pickle
0028except ImportError:
0029    import pickle
0030import weakref
0031
0032from formencode import compound, validators
0033from .classregistry import findClass
0034# Sadly the name "constraints" conflicts with many of the function
0035# arguments in this module, so we rename it:
0036from . import constraints as constrs
0037from . import converters
0038from . import sqlbuilder
0039from .styles import capword
0040from .compat import PY2, string_type, unicode_type, buffer_type
0041
0042import datetime
0043datetime_available = True
0044
0045try:
0046    from mx import DateTime
0047except ImportError:
0048    try:
0049        # old version of mxDateTime,
0050        # or Zope's Version if we're running with Zope
0051        import DateTime
0052    except ImportError:
0053        mxdatetime_available = False
0054    else:
0055        mxdatetime_available = True
0056else:
0057    mxdatetime_available = True
0058
0059DATETIME_IMPLEMENTATION = "datetime"
0060MXDATETIME_IMPLEMENTATION = "mxDateTime"
0061
0062if mxdatetime_available:
0063    if hasattr(DateTime, "Time"):
0064        DateTimeType = type(DateTime.now())
0065        TimeType = type(DateTime.Time())
0066    else:  # Zope
0067        DateTimeType = type(DateTime.DateTime())
0068        TimeType = type(DateTime.DateTime.Time(DateTime.DateTime()))
0069
0070__all__ = ["datetime_available", "mxdatetime_available",
0071           "default_datetime_implementation", "DATETIME_IMPLEMENTATION"]
0072
0073if mxdatetime_available:
0074    __all__.append("MXDATETIME_IMPLEMENTATION")
0075
0076default_datetime_implementation = DATETIME_IMPLEMENTATION
0077
0078if not PY2:
0079    # alias for python 3 compatibility
0080    long = int
0081    # This is to satisfy flake8 under python 3
0082    unicode = str
0083
0084NoDefault = sqlbuilder.NoDefault
0085
0086
0087def use_microseconds(use=True):
0088    if use:
0089        SODateTimeCol.datetimeFormat = '%Y-%m-%d %H:%M:%S.%f'
0090        SOTimeCol.timeFormat = '%H:%M:%S.%f'
0091        dt_types = [(datetime.datetime, converters.DateTimeConverterMS),
0092                    (datetime.time, converters.TimeConverterMS)]
0093    else:
0094        SODateTimeCol.datetimeFormat = '%Y-%m-%d %H:%M:%S'
0095        SOTimeCol.timeFormat = '%H:%M:%S'
0096        dt_types = [(datetime.datetime, converters.DateTimeConverter),
0097                    (datetime.time, converters.TimeConverter)]
0098    for dt_type, converter in dt_types:
0099        converters.registerConverter(dt_type, converter)
0100
0101
0102__all__.append("use_microseconds")
0103
0104
0105creationOrder = count()
0106
0107########################################
0108# Columns
0109########################################
0110
0111# Col is essentially a column definition, it doesn't have much logic to it.
0112
0113
0114class SOCol(object):
0115
0116    def __init__(self,
0117                 name,
0118                 soClass,
0119                 creationOrder,
0120                 dbName=None,
0121                 default=NoDefault,
0122                 defaultSQL=None,
0123                 foreignKey=None,
0124                 alternateID=False,
0125                 alternateMethodName=None,
0126                 constraints=None,
0127                 notNull=NoDefault,
0128                 notNone=NoDefault,
0129                 unique=NoDefault,
0130                 sqlType=None,
0131                 columnDef=None,
0132                 validator=None,
0133                 validator2=None,
0134                 immutable=False,
0135                 cascade=None,
0136                 lazy=False,
0137                 noCache=False,
0138                 forceDBName=False,
0139                 title=None,
0140                 tags=[],
0141                 origName=None,
0142                 dbEncoding=None,
0143                 extra_vars=None):
0144
0145        super(SOCol, self).__init__()
0146
0147        # This isn't strictly true, since we *could* use backquotes or
0148        # " or something (database-specific) around column names, but
0149        # why would anyone *want* to use a name like that?
0150        # @@: I suppose we could actually add backquotes to the
0151        # dbName if we needed to...
0152        if not forceDBName:
0153            assert sqlbuilder.sqlIdentifier(name), (
0154                'Name must be SQL-safe '
0155                '(letters, numbers, underscores): %s (or use forceDBName=True)'
0156                % repr(name))
0157        assert name != 'id', (
0158            'The column name "id" is reserved for SQLObject use '
0159            '(and is implicitly created).')
0160        assert name, "You must provide a name for all columns"
0161
0162        self.columnDef = columnDef
0163        self.creationOrder = creationOrder
0164
0165        self.immutable = immutable
0166
0167        # cascade can be one of:
0168        # None: no constraint is generated
0169        # True: a CASCADE constraint is generated
0170        # False: a RESTRICT constraint is generated
0171        # 'null': a SET NULL trigger is generated
0172        if isinstance(cascade, str):
0173            assert cascade == 'null', (
0174                "The only string value allowed for cascade is 'null' "
0175                "(you gave: %r)" % cascade)
0176        self.cascade = cascade
0177
0178        if not isinstance(constraints, (list, tuple)):
0179            constraints = [constraints]
0180        self.constraints = self.autoConstraints() + constraints
0181
0182        self.notNone = False
0183        if notNull is not NoDefault:
0184            self.notNone = notNull
0185            assert notNone is NoDefault or (not notNone) == (not notNull), (
0186                "The notNull and notNone arguments are aliases, "
0187                "and must not conflict.  "
0188                "You gave notNull=%r, notNone=%r" % (notNull, notNone))
0189        elif notNone is not NoDefault:
0190            self.notNone = notNone
0191        if self.notNone:
0192            self.constraints = [constrs.notNull] + self.constraints
0193
0194        self.name = name
0195        self.soClass = soClass
0196        self._default = default
0197        self.defaultSQL = defaultSQL
0198        self.customSQLType = sqlType
0199
0200        # deal with foreign keys
0201        self.foreignKey = foreignKey
0202        if self.foreignKey:
0203            if origName is not None:
0204                idname = soClass.sqlmeta.style.instanceAttrToIDAttr(origName)
0205            else:
0206                idname = soClass.sqlmeta.style.instanceAttrToIDAttr(name)
0207            if self.name != idname:
0208                self.foreignName = self.name
0209                self.name = idname
0210            else:
0211                self.foreignName = soClass.sqlmeta.style.                      instanceIDAttrToAttr(self.name)
0213        else:
0214            self.foreignName = None
0215
0216        # if they don't give us a specific database name for
0217        # the column, we separate the mixedCase into mixed_case
0218        # and assume that.
0219        if dbName is None:
0220            self.dbName = soClass.sqlmeta.style.pythonAttrToDBColumn(self.name)
0221        else:
0222            self.dbName = dbName
0223
0224        # alternateID means that this is a unique column that
0225        # can be used to identify rows
0226        self.alternateID = alternateID
0227
0228        if unique is NoDefault:
0229            self.unique = alternateID
0230        else:
0231            self.unique = unique
0232        if self.unique and alternateMethodName is None:
0233            self.alternateMethodName = 'by' + capword(self.name)
0234        else:
0235            self.alternateMethodName = alternateMethodName
0236
0237        _validators = self.createValidators()
0238        if validator:
0239            _validators.append(validator)
0240        if validator2:
0241            _validators.insert(0, validator2)
0242        _vlen = len(_validators)
0243        if _vlen:
0244            for _validator in _validators:
0245                _validator.soCol = weakref.proxy(self)
0246        if _vlen == 0:
0247            self.validator = None  # Set sef.{from,to}_python
0248        elif _vlen == 1:
0249            self.validator = _validators[0]
0250        elif _vlen > 1:
0251            self.validator = compound.All.join(
0252                _validators[0], *_validators[1:])
0253        self.noCache = noCache
0254        self.lazy = lazy
0255        # this is in case of ForeignKey, where we rename the column
0256        # and append an ID
0257        self.origName = origName or name
0258        self.title = title
0259        self.tags = tags
0260        self.dbEncoding = dbEncoding
0261
0262        if extra_vars:
0263            for name, value in extra_vars.items():
0264                setattr(self, name, value)
0265
0266    def _set_validator(self, value):
0267        self._validator = value
0268        if self._validator:
0269            self.to_python = self._validator.to_python
0270            self.from_python = self._validator.from_python
0271        else:
0272            self.to_python = None
0273            self.from_python = None
0274
0275    def _get_validator(self):
0276        return self._validator
0277
0278    validator = property(_get_validator, _set_validator)
0279
0280    def createValidators(self):
0281        """Create a list of validators for the column."""
0282        return []
0283
0284    def autoConstraints(self):
0285        return []
0286
0287    def _get_default(self):
0288        # A default can be a callback or a plain value,
0289        # here we resolve the callback
0290        if self._default is NoDefault:
0291            return NoDefault
0292        elif hasattr(self._default, '__sqlrepr__'):
0293            return self._default
0294        elif callable(self._default):
0295            return self._default()
0296        else:
0297            return self._default
0298    default = property(_get_default, None, None)
0299
0300    def _get_joinName(self):
0301        return self.soClass.sqlmeta.style.instanceIDAttrToAttr(self.name)
0302    joinName = property(_get_joinName, None, None)
0303
0304    def __repr__(self):
0305        r = '<%s %s' % (self.__class__.__name__, self.name)
0306        if self.default is not NoDefault:
0307            r += ' default=%s' % repr(self.default)
0308        if self.foreignKey:
0309            r += ' connected to %s' % self.foreignKey
0310        if self.alternateID:
0311            r += ' alternate ID'
0312        if self.notNone:
0313            r += ' not null'
0314        return r + '>'
0315
0316    def createSQL(self):
0317        return ' '.join([self._sqlType()] + self._extraSQL())
0318
0319    def _extraSQL(self):
0320        result = []
0321        if self.notNone or self.alternateID:
0322            result.append('NOT NULL')
0323        if self.unique or self.alternateID:
0324            result.append('UNIQUE')
0325        if self.defaultSQL is not None:
0326            result.append("DEFAULT %s" % self.defaultSQL)
0327        return result
0328
0329    def _sqlType(self):
0330        if self.customSQLType is None:
0331            raise ValueError("Col %s (%s) cannot be used for automatic "
0332                             "schema creation (too abstract)" %
0333                             (self.name, self.__class__))
0334        else:
0335            return self.customSQLType
0336
0337    def _mysqlType(self):
0338        return self._sqlType()
0339
0340    def _postgresType(self):
0341        return self._sqlType()
0342
0343    def _sqliteType(self):
0344        # SQLite is naturally typeless, so as a fallback it uses
0345        # no type.
0346        try:
0347            return self._sqlType()
0348        except ValueError:
0349            return ''
0350
0351    def _sybaseType(self):
0352        return self._sqlType()
0353
0354    def _mssqlType(self):
0355        return self._sqlType()
0356
0357    def _firebirdType(self):
0358        return self._sqlType()
0359
0360    def _maxdbType(self):
0361        return self._sqlType()
0362
0363    def mysqlCreateSQL(self, connection=None):
0364        self.connection = connection
0365        return ' '.join([self.dbName, self._mysqlType()] + self._extraSQL())
0366
0367    def postgresCreateSQL(self):
0368        return ' '.join([self.dbName, self._postgresType()] + self._extraSQL())
0369
0370    def sqliteCreateSQL(self):
0371        return ' '.join([self.dbName, self._sqliteType()] + self._extraSQL())
0372
0373    def sybaseCreateSQL(self):
0374        return ' '.join([self.dbName, self._sybaseType()] + self._extraSQL())
0375
0376    def mssqlCreateSQL(self, connection=None):
0377        self.connection = connection
0378        return ' '.join([self.dbName, self._mssqlType()] + self._extraSQL())
0379
0380    def firebirdCreateSQL(self):
0381        # Ian Sparks pointed out that fb is picky about the order
0382        # of the NOT NULL clause in a create statement.  So, we handle
0383        # them differently for Enum columns.
0384        if not isinstance(self, SOEnumCol):
0385            return ' '.join(
0386                [self.dbName, self._firebirdType()] + self._extraSQL())
0387        else:
0388            return ' '.join(
0389                [self.dbName] + [self._firebirdType()[0]] +
0390                self._extraSQL() + [self._firebirdType()[1]])
0391
0392    def maxdbCreateSQL(self):
0393        return ' '.join([self.dbName, self._maxdbType()] + self._extraSQL())
0394
0395    def __get__(self, obj, type=None):
0396        if obj is None:
0397            # class attribute, return the descriptor itself
0398            return self
0399        if obj.sqlmeta._obsolete:
0400            raise RuntimeError('The object <%s %s> is obsolete' % (
0401                obj.__class__.__name__, obj.id))
0402        if obj.sqlmeta.cacheColumns:
0403            columns = obj.sqlmeta._columnCache
0404            if columns is None:
0405                obj.sqlmeta.loadValues()
0406            try:
0407                return columns[name]  # noqa
0408            except KeyError:
0409                return obj.sqlmeta.loadColumn(self)
0410        else:
0411            return obj.sqlmeta.loadColumn(self)
0412
0413    def __set__(self, obj, value):
0414        if self.immutable:
0415            raise AttributeError("The column %s.%s is immutable" %
0416                                 (obj.__class__.__name__,
0417                                  self.name))
0418        obj.sqlmeta.setColumn(self, value)
0419
0420    def __delete__(self, obj):
0421        raise AttributeError("I can't be deleted from %r" % obj)
0422
0423    def getDbEncoding(self, state, default='utf-8'):
0424        if self.dbEncoding:
0425            return self.dbEncoding
0426        dbEncoding = state.soObject.sqlmeta.dbEncoding
0427        if dbEncoding:
0428            return dbEncoding
0429        try:
0430            connection = state.connection or state.soObject._connection
0431        except AttributeError:
0432            dbEncoding = None
0433        else:
0434            dbEncoding = getattr(connection, "dbEncoding", None)
0435        if not dbEncoding:
0436            dbEncoding = default
0437        return dbEncoding
0438
0439
0440class Col(object):
0441
0442    baseClass = SOCol
0443
0444    def __init__(self, name=None, **kw):
0445        super(Col, self).__init__()
0446        self.__dict__['_name'] = name
0447        self.__dict__['_kw'] = kw
0448        self.__dict__['creationOrder'] = next(creationOrder)
0449        self.__dict__['_extra_vars'] = {}
0450
0451    def _set_name(self, value):
0452        assert self._name is None or self._name == value, (
0453            "You cannot change a name after it has already been set "
0454            "(from %s to %s)" % (self.name, value))
0455        self.__dict__['_name'] = value
0456
0457    def _get_name(self):
0458        return self._name
0459
0460    name = property(_get_name, _set_name)
0461
0462    def withClass(self, soClass):
0463        return self.baseClass(soClass=soClass, name=self._name,
0464                              creationOrder=self.creationOrder,
0465                              columnDef=self,
0466                              extra_vars=self._extra_vars,
0467                              **self._kw)
0468
0469    def __setattr__(self, var, value):
0470        if var == 'name':
0471            super(Col, self).__setattr__(var, value)
0472            return
0473        self._extra_vars[var] = value
0474
0475    def __repr__(self):
0476        return '<%s %s %s>' % (
0477            self.__class__.__name__, hex(abs(id(self)))[2:],
0478            self._name or '(unnamed)')
0479
0480
0481class SOValidator(validators.Validator):
0482    def getDbEncoding(self, state, default='utf-8'):
0483        try:
0484            return self.dbEncoding
0485        except AttributeError:
0486            return self.soCol.getDbEncoding(state, default=default)
0487
0488
0489class SOStringLikeCol(SOCol):
0490    """A common ancestor for SOStringCol and SOUnicodeCol"""
0491    def __init__(self, **kw):
0492        self.length = kw.pop('length', None)
0493        self.varchar = kw.pop('varchar', 'auto')
0494        self.char_binary = kw.pop('char_binary', None)  # A hack for MySQL
0495        if not self.length:
0496            assert self.varchar == 'auto' or not self.varchar,                   "Without a length strings are treated as TEXT, not varchar"
0498            self.varchar = False
0499        elif self.varchar == 'auto':
0500            self.varchar = True
0501
0502        super(SOStringLikeCol, self).__init__(**kw)
0503
0504    def autoConstraints(self):
0505        constraints = [constrs.isString]
0506        if self.length is not None:
0507            constraints += [constrs.MaxLength(self.length)]
0508        return constraints
0509
0510    def _sqlType(self):
0511        if self.customSQLType is not None:
0512            return self.customSQLType
0513        if not self.length:
0514            return 'TEXT'
0515        elif self.varchar:
0516            return 'VARCHAR(%i)' % self.length
0517        else:
0518            return 'CHAR(%i)' % self.length
0519
0520    def _check_case_sensitive(self, db):
0521        if self.char_binary:
0522            raise ValueError("%s does not support "
0523                             "binary character columns" % db)
0524
0525    def _mysqlType(self):
0526        type = self._sqlType()
0527        if self.char_binary:
0528            type += " BINARY"
0529        return type
0530
0531    def _postgresType(self):
0532        self._check_case_sensitive("PostgreSQL")
0533        return super(SOStringLikeCol, self)._postgresType()
0534
0535    def _sqliteType(self):
0536        self._check_case_sensitive("SQLite")
0537        return super(SOStringLikeCol, self)._sqliteType()
0538
0539    def _sybaseType(self):
0540        self._check_case_sensitive("SYBASE")
0541        type = self._sqlType()
0542        if not self.notNone and not self.alternateID:
0543            type += ' NULL'
0544        return type
0545
0546    def _mssqlType(self):
0547        if self.customSQLType is not None:
0548            return self.customSQLType
0549        if not self.length:
0550            if self.connection and self.connection.can_use_max_types():
0551                type = 'VARCHAR(MAX)'
0552            else:
0553                type = 'VARCHAR(4000)'
0554        elif self.varchar:
0555            type = 'VARCHAR(%i)' % self.length
0556        else:
0557            type = 'CHAR(%i)' % self.length
0558        if not self.notNone and not self.alternateID:
0559            type += ' NULL'
0560        return type
0561
0562    def _firebirdType(self):
0563        self._check_case_sensitive("FireBird")
0564        if not self.length:
0565            return 'BLOB SUB_TYPE TEXT'
0566        else:
0567            return self._sqlType()
0568
0569    def _maxdbType(self):
0570        self._check_case_sensitive("SAP DB/MaxDB")
0571        if not self.length:
0572            return 'LONG ASCII'
0573        else:
0574            return self._sqlType()
0575
0576
0577class StringValidator(SOValidator):
0578
0579    def to_python(self, value, state):
0580        if value is None:
0581            return None
0582        try:
0583            connection = state.connection or state.soObject._connection
0584            binaryType = connection._binaryType
0585            dbName = connection.dbName
0586        except AttributeError:
0587            binaryType = type(None)  # Just a simple workaround
0588        dbEncoding = self.getDbEncoding(state, default='ascii')
0589        if isinstance(value, unicode_type):
0590            if PY2:
0591                return value.encode(dbEncoding)
0592            return value
0593        if self.dataType and isinstance(value, self.dataType):
0594            return value
0595        if isinstance(value,
0596                      (str, buffer_type, binaryType,
0597                       sqlbuilder.SQLExpression)):
0598            return value
0599        if hasattr(value, '__unicode__'):
0600            return unicode(value).encode(dbEncoding)
0601        if dbName == 'mysql' and not PY2 and isinstance(value, bytes):
0602            return value.decode('ascii', errors='surrogateescape')
0603        raise validators.Invalid(
0604            "expected a str in the StringCol '%s', got %s %r instead" % (
0605                self.name, type(value), value), value, state)
0606
0607    from_python = to_python
0608
0609
0610class SOStringCol(SOStringLikeCol):
0611
0612    def createValidators(self, dataType=None):
0613        return [StringValidator(name=self.name, dataType=dataType)] +               super(SOStringCol, self).createValidators()
0615
0616
0617class StringCol(Col):
0618    baseClass = SOStringCol
0619
0620
0621class NQuoted(sqlbuilder.SQLExpression):
0622    def __init__(self, value):
0623        assert isinstance(value, unicode_type)
0624        self.value = value
0625
0626    def __hash__(self):
0627        return hash(self.value)
0628
0629    def __sqlrepr__(self, db):
0630        assert db == 'mssql'
0631        return "N" + sqlbuilder.sqlrepr(self.value, db)
0632
0633
0634class UnicodeStringValidator(SOValidator):
0635
0636    def to_python(self, value, state):
0637        if value is None:
0638            return None
0639        if isinstance(value, (unicode_type, sqlbuilder.SQLExpression)):
0640            return value
0641        if isinstance(value, str):
0642            return value.decode(self.getDbEncoding(state))
0643        if isinstance(value, array):  # MySQL
0644            return value.tostring().decode(self.getDbEncoding(state))
0645        if hasattr(value, '__unicode__'):
0646            return unicode(value)
0647        raise validators.Invalid(
0648            "expected a str or a unicode in the UnicodeCol '%s', "
0649            "got %s %r instead" % (
0650                self.name, type(value), value), value, state)
0651
0652    def from_python(self, value, state):
0653        if value is None:
0654            return None
0655        if isinstance(value, (str, sqlbuilder.SQLExpression)):
0656            return value
0657        if isinstance(value, unicode_type):
0658            try:
0659                connection = state.connection or state.soObject._connection
0660            except AttributeError:
0661                pass
0662            else:
0663                if connection.dbName == 'mssql':
0664                    return NQuoted(value)
0665            return value.encode(self.getDbEncoding(state))
0666        if hasattr(value, '__unicode__'):
0667            return unicode(value).encode(self.getDbEncoding(state))
0668        raise validators.Invalid(
0669            "expected a str or a unicode in the UnicodeCol '%s', "
0670            "got %s %r instead" % (
0671                self.name, type(value), value), value, state)
0672
0673
0674class SOUnicodeCol(SOStringLikeCol):
0675    def _mssqlType(self):
0676        if self.customSQLType is not None:
0677            return self.customSQLType
0678        return 'N' + super(SOUnicodeCol, self)._mssqlType()
0679
0680    def createValidators(self):
0681        return [UnicodeStringValidator(name=self.name)] +               super(SOUnicodeCol, self).createValidators()
0683
0684
0685class UnicodeCol(Col):
0686    baseClass = SOUnicodeCol
0687
0688
0689class UuidValidator(SOValidator):
0690
0691    def to_python(self, value, state):
0692        if value is None:
0693            return None
0694        if isinstance(value, str):
0695            return UUID(value)
0696        raise validators.Invalid(
0697            "expected string in the UuidCol '%s', "
0698            "got %s %r instead" % (
0699                self.name, type(value), value), value, state)
0700
0701    def from_python(self, value, state):
0702        if value is None:
0703            return None
0704        if isinstance(value, UUID):
0705            return str(value)
0706        raise validators.Invalid(
0707            "expected uuid in the UuidCol '%s', "
0708            "got %s %r instead" % (
0709                self.name, type(value), value), value, state)
0710
0711
0712class SOUuidCol(SOCol):
0713    def createValidators(self):
0714        return [UuidValidator(name=self.name)] +               super(SOUuidCol, self).createValidators()
0716
0717    def _sqlType(self):
0718        return 'VARCHAR(36)'
0719
0720    def _postgresType(self):
0721        return 'UUID'
0722
0723
0724class UuidCol(Col):
0725    baseClass = SOUuidCol
0726
0727
0728class IntValidator(SOValidator):
0729
0730    def to_python(self, value, state):
0731        if value is None:
0732            return None
0733        if isinstance(value, (int, long, sqlbuilder.SQLExpression)):
0734            return value
0735        for converter, attr_name in (int, '__int__'), (long, '__long__'):
0736            if hasattr(value, attr_name):
0737                try:
0738                    return converter(value)
0739                except:
0740                    break
0741        raise validators.Invalid(
0742            "expected an int in the IntCol '%s', got %s %r instead" % (
0743                self.name, type(value), value), value, state)
0744
0745    from_python = to_python
0746
0747
0748class SOIntCol(SOCol):
0749    # 3-03 @@: support precision, maybe max and min directly
0750    def __init__(self, **kw):
0751        self.length = kw.pop('length', None)
0752        self.unsigned = bool(kw.pop('unsigned', None))
0753        self.zerofill = bool(kw.pop('zerofill', None))
0754        SOCol.__init__(self, **kw)
0755
0756    def autoConstraints(self):
0757        return [constrs.isInt]
0758
0759    def createValidators(self):
0760        return [IntValidator(name=self.name)] +               super(SOIntCol, self).createValidators()
0762
0763    def addSQLAttrs(self, str):
0764        _ret = str
0765        if str is None or len(str) < 1:
0766            return None
0767
0768        if self.length and self.length >= 1:
0769            _ret = "%s(%d)" % (_ret, self.length)
0770        if self.unsigned:
0771            _ret = _ret + " UNSIGNED"
0772        if self.zerofill:
0773            _ret = _ret + " ZEROFILL"
0774        return _ret
0775
0776    def _sqlType(self):
0777        return self.addSQLAttrs("INT")
0778
0779
0780class IntCol(Col):
0781    baseClass = SOIntCol
0782
0783
0784class SOTinyIntCol(SOIntCol):
0785    def _sqlType(self):
0786        return self.addSQLAttrs("TINYINT")
0787
0788
0789class TinyIntCol(Col):
0790    baseClass = SOTinyIntCol
0791
0792
0793class SOSmallIntCol(SOIntCol):
0794    def _sqlType(self):
0795        return self.addSQLAttrs("SMALLINT")
0796
0797
0798class SmallIntCol(Col):
0799    baseClass = SOSmallIntCol
0800
0801
0802class SOMediumIntCol(SOIntCol):
0803    def _sqlType(self):
0804        return self.addSQLAttrs("MEDIUMINT")
0805
0806
0807class MediumIntCol(Col):
0808    baseClass = SOMediumIntCol
0809
0810
0811class SOBigIntCol(SOIntCol):
0812    def _sqlType(self):
0813        return self.addSQLAttrs("BIGINT")
0814
0815
0816class BigIntCol(Col):
0817    baseClass = SOBigIntCol
0818
0819
0820class BoolValidator(SOValidator):
0821
0822    def to_python(self, value, state):
0823        if value is None:
0824            return None
0825        if isinstance(value, (bool, sqlbuilder.SQLExpression)):
0826            return value
0827        if isinstance(value, (int, long)) or hasattr(value, '__nonzero__'):
0828            return bool(value)
0829        raise validators.Invalid(
0830            "expected a bool or an int in the BoolCol '%s', "
0831            "got %s %r instead" % (
0832                self.name, type(value), value), value, state)
0833
0834    from_python = to_python
0835
0836
0837class SOBoolCol(SOCol):
0838    def autoConstraints(self):
0839        return [constrs.isBool]
0840
0841    def createValidators(self):
0842        return [BoolValidator(name=self.name)] +               super(SOBoolCol, self).createValidators()
0844
0845    def _postgresType(self):
0846        return 'BOOL'
0847
0848    def _mysqlType(self):
0849        return "BOOL"
0850
0851    def _sybaseType(self):
0852        return "BIT"
0853
0854    def _mssqlType(self):
0855        return "BIT"
0856
0857    def _firebirdType(self):
0858        return 'INT'
0859
0860    def _maxdbType(self):
0861        return "BOOLEAN"
0862
0863    def _sqliteType(self):
0864        return "BOOLEAN"
0865
0866
0867class BoolCol(Col):
0868    baseClass = SOBoolCol
0869
0870
0871class FloatValidator(SOValidator):
0872
0873    def to_python(self, value, state):
0874        if value is None:
0875            return None
0876        if isinstance(value, (float, int, long, sqlbuilder.SQLExpression)):
0877            return value
0878        for converter, attr_name in (
0879                (float, '__float__'), (int, '__int__'), (long, '__long__')):
0880            if hasattr(value, attr_name):
0881                try:
0882                    return converter(value)
0883                except:
0884                    break
0885        raise validators.Invalid(
0886            "expected a float in the FloatCol '%s', got %s %r instead" % (
0887                self.name, type(value), value), value, state)
0888
0889    from_python = to_python
0890
0891
0892class SOFloatCol(SOCol):
0893    # 3-03 @@: support precision (e.g., DECIMAL)
0894
0895    def autoConstraints(self):
0896        return [constrs.isFloat]
0897
0898    def createValidators(self):
0899        return [FloatValidator(name=self.name)] +               super(SOFloatCol, self).createValidators()
0901
0902    def _sqlType(self):
0903        return 'FLOAT'
0904
0905    def _mysqlType(self):
0906        return "DOUBLE PRECISION"
0907
0908
0909class FloatCol(Col):
0910    baseClass = SOFloatCol
0911
0912
0913class SOKeyCol(SOCol):
0914    key_type = {int: "INT", str: "TEXT"}
0915
0916    # 3-03 @@: this should have a simplified constructor
0917    # Should provide foreign key information for other DBs.
0918
0919    def __init__(self, **kw):
0920        self.refColumn = kw.pop('refColumn', None)
0921        super(SOKeyCol, self).__init__(**kw)
0922
0923    def _idType(self):
0924        return self.soClass.sqlmeta.idType
0925
0926    def _sqlType(self):
0927        return self.key_type[self._idType()]
0928
0929    def _sybaseType(self):
0930        key_type = {int: "NUMERIC(18,0) NULL", str: "TEXT"}
0931        return key_type[self._idType()]
0932
0933    def _mssqlType(self):
0934        key_type = {int: "INT NULL", str: "TEXT"}
0935        return key_type[self._idType()]
0936
0937
0938class KeyCol(Col):
0939
0940    baseClass = SOKeyCol
0941
0942
0943class ForeignKeyValidator(SOValidator):
0944
0945    def __init__(self, *args, **kw):
0946        super(ForeignKeyValidator, self).__init__(*args, **kw)
0947        self.fkIDType = None
0948
0949    def from_python(self, value, state):
0950        if value is None:
0951            return None
0952        # Avoid importing the main module
0953        # to get the SQLObject class for isinstance
0954        if hasattr(value, 'sqlmeta'):
0955            return value
0956        if self.fkIDType is None:
0957            otherTable = findClass(self.soCol.foreignKey,
0958                                   self.soCol.soClass.sqlmeta.registry)
0959            self.fkIDType = otherTable.sqlmeta.idType
0960        try:
0961            value = self.fkIDType(value)
0962            return value
0963        except (ValueError, TypeError):
0964            pass
0965        raise validators.Invalid("expected a %r for the ForeignKey '%s', "
0966                                 "got %s %r instead" %
0967                                 (self.fkIDType, self.name,
0968                                  type(value), value), value, state)
0969
0970
0971class SOForeignKey(SOKeyCol):
0972
0973    def __init__(self, **kw):
0974        foreignKey = kw['foreignKey']
0975        style = kw['soClass'].sqlmeta.style
0976        if kw.get('name'):
0977            kw['origName'] = kw['name']
0978            kw['name'] = style.instanceAttrToIDAttr(kw['name'])
0979        else:
0980            kw['name'] = style.instanceAttrToIDAttr(
0981                style.pythonClassToAttr(foreignKey))
0982        super(SOForeignKey, self).__init__(**kw)
0983
0984    def createValidators(self):
0985        return [ForeignKeyValidator(name=self.name)] +               super(SOForeignKey, self).createValidators()
0987
0988    def _idType(self):
0989        other = findClass(self.foreignKey, self.soClass.sqlmeta.registry)
0990        return other.sqlmeta.idType
0991
0992    def sqliteCreateSQL(self):
0993        sql = SOKeyCol.sqliteCreateSQL(self)
0994        other = findClass(self.foreignKey, self.soClass.sqlmeta.registry)
0995        tName = other.sqlmeta.table
0996        idName = self.refColumn or other.sqlmeta.idName
0997        if self.cascade is not None:
0998            if self.cascade == 'null':
0999                action = 'ON DELETE SET NULL'
1000            elif self.cascade:
1001                action = 'ON DELETE CASCADE'
1002            else:
1003                action = 'ON DELETE RESTRICT'
1004        else:
1005            action = ''
1006        constraint = ('CONSTRAINT %(colName)s_exists '
1007                      # 'FOREIGN KEY(%(colName)s) '
1008                      'REFERENCES %(tName)s(%(idName)s) '
1009                      '%(action)s' %
1010                      {'tName': tName,
1011                       'colName': self.dbName,
1012                       'idName': idName,
1013                       'action': action})
1014        sql = ' '.join([sql, constraint])
1015        return sql
1016
1017    def postgresCreateSQL(self):
1018        sql = SOKeyCol.postgresCreateSQL(self)
1019        return sql
1020
1021    def postgresCreateReferenceConstraint(self):
1022        sTName = self.soClass.sqlmeta.table
1023        other = findClass(self.foreignKey, self.soClass.sqlmeta.registry)
1024        tName = other.sqlmeta.table
1025        idName = self.refColumn or other.sqlmeta.idName
1026        if self.cascade is not None:
1027            if self.cascade == 'null':
1028                action = 'ON DELETE SET NULL'
1029            elif self.cascade:
1030                action = 'ON DELETE CASCADE'
1031            else:
1032                action = 'ON DELETE RESTRICT'
1033        else:
1034            action = ''
1035        constraint = ('ALTER TABLE %(sTName)s '
1036                      'ADD CONSTRAINT %(colName)s_exists '
1037                      'FOREIGN KEY (%(colName)s) '
1038                      'REFERENCES %(tName)s (%(idName)s) '
1039                      '%(action)s' %
1040                      {'tName': tName,
1041                       'colName': self.dbName,
1042                       'idName': idName,
1043                       'action': action,
1044                       'sTName': sTName})
1045        return constraint
1046
1047    def mysqlCreateReferenceConstraint(self):
1048        sTName = self.soClass.sqlmeta.table
1049        sTLocalName = sTName.split('.')[-1]
1050        other = findClass(self.foreignKey, self.soClass.sqlmeta.registry)
1051        tName = other.sqlmeta.table
1052        idName = self.refColumn or other.sqlmeta.idName
1053        if self.cascade is not None:
1054            if self.cascade == 'null':
1055                action = 'ON DELETE SET NULL'
1056            elif self.cascade:
1057                action = 'ON DELETE CASCADE'
1058            else:
1059                action = 'ON DELETE RESTRICT'
1060        else:
1061            action = ''
1062        constraint = ('ALTER TABLE %(sTName)s '
1063                      'ADD CONSTRAINT %(sTLocalName)s_%(colName)s_exists '
1064                      'FOREIGN KEY (%(colName)s) '
1065                      'REFERENCES %(tName)s (%(idName)s) '
1066                      '%(action)s' %
1067                      {'tName': tName,
1068                       'colName': self.dbName,
1069                       'idName': idName,
1070                       'action': action,
1071                       'sTName': sTName,
1072                       'sTLocalName': sTLocalName})
1073        return constraint
1074
1075    def mysqlCreateSQL(self, connection=None):
1076        return SOKeyCol.mysqlCreateSQL(self, connection)
1077
1078    def sybaseCreateSQL(self):
1079        sql = SOKeyCol.sybaseCreateSQL(self)
1080        other = findClass(self.foreignKey, self.soClass.sqlmeta.registry)
1081        tName = other.sqlmeta.table
1082        idName = self.refColumn or other.sqlmeta.idName
1083        reference = ('REFERENCES %(tName)s(%(idName)s) ' %
1084                     {'tName': tName,
1085                      'idName': idName})
1086        sql = ' '.join([sql, reference])
1087        return sql
1088
1089    def sybaseCreateReferenceConstraint(self):
1090        # @@: Code from above should be moved here
1091        return None
1092
1093    def mssqlCreateSQL(self, connection=None):
1094        sql = SOKeyCol.mssqlCreateSQL(self, connection)
1095        other = findClass(self.foreignKey, self.soClass.sqlmeta.registry)
1096        tName = other.sqlmeta.table
1097        idName = self.refColumn or other.sqlmeta.idName
1098        reference = ('REFERENCES %(tName)s(%(idName)s) ' %
1099                     {'tName': tName,
1100                      'idName': idName})
1101        sql = ' '.join([sql, reference])
1102        return sql
1103
1104    def mssqlCreateReferenceConstraint(self):
1105        # @@: Code from above should be moved here
1106        return None
1107
1108    def maxdbCreateSQL(self):
1109        other = findClass(self.foreignKey, self.soClass.sqlmeta.registry)
1110        fidName = self.dbName
1111        # I assume that foreign key name is identical
1112        # to the id of the reference table
1113        sql = ' '.join([fidName, self._maxdbType()])
1114        tName = other.sqlmeta.table
1115        idName = self.refColumn or other.sqlmeta.idName
1116        sql = sql + ',' + '\n'
1117        sql = sql + 'FOREIGN KEY (%s) REFERENCES %s(%s)' % (fidName, tName,
1118                                                            idName)
1119        return sql
1120
1121    def maxdbCreateReferenceConstraint(self):
1122        # @@: Code from above should be moved here
1123        return None
1124
1125
1126class ForeignKey(KeyCol):
1127
1128    baseClass = SOForeignKey
1129
1130    def __init__(self, foreignKey=None, **kw):
1131        super(ForeignKey, self).__init__(foreignKey=foreignKey, **kw)
1132
1133
1134class EnumValidator(SOValidator):
1135
1136    def to_python(self, value, state):
1137        if value in self.enumValues:
1138            # Only encode on python 2 - on python 3, the database driver
1139            # will handle this
1140            if isinstance(value, unicode_type) and PY2:
1141                dbEncoding = self.getDbEncoding(state)
1142                value = value.encode(dbEncoding)
1143            return value
1144        elif not self.notNone and value is None:
1145            return None
1146        raise validators.Invalid(
1147            "expected a member of %r in the EnumCol '%s', got %r instead" % (
1148                self.enumValues, self.name, value), value, state)
1149
1150    from_python = to_python
1151
1152
1153class SOEnumCol(SOCol):
1154
1155    def __init__(self, **kw):
1156        self.enumValues = kw.pop('enumValues', None)
1157        assert self.enumValues is not None,               'You must provide an enumValues keyword argument'
1159        super(SOEnumCol, self).__init__(**kw)
1160
1161    def autoConstraints(self):
1162        return [constrs.isString, constrs.InList(self.enumValues)]
1163
1164    def createValidators(self):
1165        return [EnumValidator(name=self.name, enumValues=self.enumValues,
1166                              notNone=self.notNone)] +               super(SOEnumCol, self).createValidators()
1168
1169    def _mysqlType(self):
1170        # We need to map None in the enum expression to an appropriate
1171        # condition on NULL
1172        if None in self.enumValues:
1173            return "ENUM(%s)" % ', '.join(
1174                [sqlbuilder.sqlrepr(v, 'mysql') for v in self.enumValues
1175                    if v is not None])
1176        else:
1177            return "ENUM(%s) NOT NULL" % ', '.join(
1178                [sqlbuilder.sqlrepr(v, 'mysql') for v in self.enumValues])
1179
1180    def _postgresType(self):
1181        length = max(map(self._getlength, self.enumValues))
1182        enumValues = ', '.join(
1183            [sqlbuilder.sqlrepr(v, 'postgres') for v in self.enumValues])
1184        checkConstraint = "CHECK (%s in (%s))" % (self.dbName, enumValues)
1185        return "VARCHAR(%i) %s" % (length, checkConstraint)
1186
1187    _sqliteType = _postgresType
1188
1189    def _sybaseType(self):
1190        return self._postgresType()
1191
1192    def _mssqlType(self):
1193        return self._postgresType()
1194
1195    def _firebirdType(self):
1196        length = max(map(self._getlength, self.enumValues))
1197        enumValues = ', '.join(
1198            [sqlbuilder.sqlrepr(v, 'firebird') for v in self.enumValues])
1199        checkConstraint = "CHECK (%s in (%s))" % (self.dbName, enumValues)
1200        # NB. Return a tuple, not a string here
1201        return "VARCHAR(%i)" % (length), checkConstraint
1202
1203    def _maxdbType(self):
1204        raise TypeError("Enum type is not supported on MAX DB")
1205
1206    def _getlength(self, obj):
1207        """
1208        None counts as 0; everything else uses len()
1209        """
1210        if obj is None:
1211            return 0
1212        else:
1213            return len(obj)
1214
1215
1216class EnumCol(Col):
1217    baseClass = SOEnumCol
1218
1219
1220class SetValidator(SOValidator):
1221    """
1222    Translates Python tuples into SQL comma-delimited SET strings.
1223    """
1224
1225    def to_python(self, value, state):
1226        if isinstance(value, str):
1227            return tuple(value.split(","))
1228        raise validators.Invalid(
1229            "expected a string in the SetCol '%s', got %s %r instead" % (
1230                self.name, type(value), value), value, state)
1231
1232    def from_python(self, value, state):
1233        if isinstance(value, string_type):
1234            value = (value,)
1235        try:
1236            return ",".join(value)
1237        except:
1238            raise validators.Invalid(
1239                "expected a string or a sequence of strings "
1240                "in the SetCol '%s', got %s %r instead" % (
1241                    self.name, type(value), value), value, state)
1242
1243
1244class SOSetCol(SOCol):
1245    def __init__(self, **kw):
1246        self.setValues = kw.pop('setValues', None)
1247        assert self.setValues is not None,               'You must provide a setValues keyword argument'
1249        super(SOSetCol, self).__init__(**kw)
1250
1251    def autoConstraints(self):
1252        return [constrs.isString, constrs.InList(self.setValues)]
1253
1254    def createValidators(self):
1255        return [SetValidator(name=self.name, setValues=self.setValues)] +               super(SOSetCol, self).createValidators()
1257
1258    def _mysqlType(self):
1259        return "SET(%s)" % ', '.join(
1260            [sqlbuilder.sqlrepr(v, 'mysql') for v in self.setValues])
1261
1262
1263class SetCol(Col):
1264    baseClass = SOSetCol
1265
1266
1267class DateTimeValidator(validators.DateValidator):
1268    def to_python(self, value, state):
1269        if value is None:
1270            return None
1271        if isinstance(value,
1272                      (datetime.datetime, datetime.date,
1273                       datetime.time, sqlbuilder.SQLExpression)):
1274            return value
1275        if mxdatetime_available:
1276            if isinstance(value, DateTimeType):
1277                # convert mxDateTime instance to datetime
1278                if (self.format.find("%H") >= 0) or                      (self.format.find("%T")) >= 0:
1280                    return datetime.datetime(value.year, value.month,
1281                                             value.day,
1282                                             value.hour, value.minute,
1283                                             int(value.second))
1284                else:
1285                    return datetime.date(value.year, value.month, value.day)
1286            elif isinstance(value, TimeType):
1287                # convert mxTime instance to time
1288                if self.format.find("%d") >= 0:
1289                    return datetime.timedelta(seconds=value.seconds)
1290                else:
1291                    return datetime.time(value.hour, value.minute,
1292                                         int(value.second))
1293        try:
1294            if self.format.find(".%f") >= 0:
1295                if '.' in value:
1296                    _value = value.split('.')
1297                    microseconds = _value[-1]
1298                    _l = len(microseconds)
1299                    if _l < 6:
1300                        _value[-1] = microseconds + '0' * (6 - _l)
1301                    elif _l > 6:
1302                        _value[-1] = microseconds[:6]
1303                    if _l != 6:
1304                        value = '.'.join(_value)
1305                else:
1306                    value += '.0'
1307            return datetime.datetime.strptime(value, self.format)
1308        except:
1309            raise validators.Invalid(
1310                "expected a date/time string of the '%s' format "
1311                "in the DateTimeCol '%s', got %s %r instead" % (
1312                    self.format, self.name, type(value), value), value, state)
1313
1314    def from_python(self, value, state):
1315        if value is None:
1316            return None
1317        if isinstance(value,
1318                      (datetime.datetime, datetime.date,
1319                       datetime.time, sqlbuilder.SQLExpression)):
1320            return value
1321        if hasattr(value, "strftime"):
1322            return value.strftime(self.format)
1323        raise validators.Invalid(
1324            "expected a datetime in the DateTimeCol '%s', "
1325            "got %s %r instead" % (
1326                self.name, type(value), value), value, state)
1327
1328if mxdatetime_available:
1329    class MXDateTimeValidator(validators.DateValidator):
1330        def to_python(self, value, state):
1331            if value is None:
1332                return None
1333            if isinstance(value,
1334                          (DateTimeType, TimeType, sqlbuilder.SQLExpression)):
1335                return value
1336            if isinstance(value, datetime.datetime):
1337                return DateTime.DateTime(value.year, value.month, value.day,
1338                                         value.hour, value.minute,
1339                                         value.second)
1340            elif isinstance(value, datetime.date):
1341                return DateTime.Date(value.year, value.month, value.day)
1342            elif isinstance(value, datetime.time):
1343                return DateTime.Time(value.hour, value.minute, value.second)
1344            try:
1345                if self.format.find(".%f") >= 0:
1346                    if '.' in value:
1347                        _value = value.split('.')
1348                        microseconds = _value[-1]
1349                        _l = len(microseconds)
1350                        if _l < 6:
1351                            _value[-1] = microseconds + '0' * (6 - _l)
1352                        elif _l > 6:
1353                            _value[-1] = microseconds[:6]
1354                        if _l != 6:
1355                            value = '.'.join(_value)
1356                    else:
1357                        value += '.0'
1358                value = datetime.datetime.strptime(value, self.format)
1359                return DateTime.DateTime(value.year, value.month, value.day,
1360                                         value.hour, value.minute,
1361                                         value.second)
1362            except:
1363                raise validators.Invalid(
1364                    "expected a date/time string of the '%s' format "
1365                    "in the DateTimeCol '%s', got %s %r instead" % (
1366                        self.format, self.name, type(value), value),
1367                    value, state)
1368
1369        def from_python(self, value, state):
1370            if value is None:
1371                return None
1372            if isinstance(value,
1373                          (DateTimeType, TimeType, sqlbuilder.SQLExpression)):
1374                return value
1375            if hasattr(value, "strftime"):
1376                return value.strftime(self.format)
1377            raise validators.Invalid(
1378                "expected a mxDateTime in the DateTimeCol '%s', "
1379                "got %s %r instead" % (
1380                    self.name, type(value), value), value, state)
1381
1382
1383class SODateTimeCol(SOCol):
1384    datetimeFormat = '%Y-%m-%d %H:%M:%S.%f'
1385
1386    def __init__(self, **kw):
1387        datetimeFormat = kw.pop('datetimeFormat', None)
1388        if datetimeFormat:
1389            self.datetimeFormat = datetimeFormat
1390        super(SODateTimeCol, self).__init__(**kw)
1391
1392    def createValidators(self):
1393        _validators = super(SODateTimeCol, self).createValidators()
1394        if default_datetime_implementation == DATETIME_IMPLEMENTATION:
1395            validatorClass = DateTimeValidator
1396        elif default_datetime_implementation == MXDATETIME_IMPLEMENTATION:
1397            validatorClass = MXDateTimeValidator
1398        if default_datetime_implementation:
1399            _validators.insert(0, validatorClass(name=self.name,
1400                                                 format=self.datetimeFormat))
1401        return _validators
1402
1403    def _mysqlType(self):
1404        if self.connection and self.connection.can_use_microseconds():
1405            return 'DATETIME(6)'
1406        else:
1407            return 'DATETIME'
1408
1409    def _postgresType(self):
1410        return 'TIMESTAMP'
1411
1412    def _sybaseType(self):
1413        return 'DATETIME'
1414
1415    def _mssqlType(self):
1416        if self.connection and self.connection.can_use_microseconds():
1417            return 'DATETIME2(6)'
1418        else:
1419            return 'DATETIME'
1420
1421    def _sqliteType(self):
1422        return 'TIMESTAMP'
1423
1424    def _firebirdType(self):
1425        return 'TIMESTAMP'
1426
1427    def _maxdbType(self):
1428        return 'TIMESTAMP'
1429
1430
1431class DateTimeCol(Col):
1432    baseClass = SODateTimeCol
1433
1434    @staticmethod
1435    def now():
1436        if default_datetime_implementation == DATETIME_IMPLEMENTATION:
1437            return datetime.datetime.now()
1438        elif default_datetime_implementation == MXDATETIME_IMPLEMENTATION:
1439            return DateTime.now()
1440        else:
1441            assert 0, ("No datetime implementation available "
1442                       "(DATETIME_IMPLEMENTATION=%r)"
1443                       % DATETIME_IMPLEMENTATION)
1444
1445
1446class DateValidator(DateTimeValidator):
1447    def to_python(self, value, state):
1448        if isinstance(value, datetime.datetime):
1449            value = value.date()
1450        if isinstance(value, (datetime.date, sqlbuilder.SQLExpression)):
1451            return value
1452        value = super(DateValidator, self).to_python(value, state)
1453        if isinstance(value, datetime.datetime):
1454            value = value.date()
1455        return value
1456
1457    from_python = to_python
1458
1459
1460class SODateCol(SOCol):
1461    dateFormat = '%Y-%m-%d'
1462
1463    def __init__(self, **kw):
1464        dateFormat = kw.pop('dateFormat', None)
1465        if dateFormat:
1466            self.dateFormat = dateFormat
1467        super(SODateCol, self).__init__(**kw)
1468
1469    def createValidators(self):
1470        """Create a validator for the column.
1471
1472        Can be overriden in descendants.
1473
1474        """
1475        _validators = super(SODateCol, self).createValidators()
1476        if default_datetime_implementation == DATETIME_IMPLEMENTATION:
1477            validatorClass = DateValidator
1478        elif default_datetime_implementation == MXDATETIME_IMPLEMENTATION:
1479            validatorClass = MXDateTimeValidator
1480        if default_datetime_implementation:
1481            _validators.insert(0, validatorClass(name=self.name,
1482                                                 format=self.dateFormat))
1483        return _validators
1484
1485    def _mysqlType(self):
1486        return 'DATE'
1487
1488    def _postgresType(self):
1489        return 'DATE'
1490
1491    def _sybaseType(self):
1492        return self._postgresType()
1493
1494    def _mssqlType(self):
1495        """
1496        SQL Server doesn't have  a DATE data type, to emulate we use a vc(10)
1497        """
1498        return 'VARCHAR(10)'
1499
1500    def _firebirdType(self):
1501        return 'DATE'
1502
1503    def _maxdbType(self):
1504        return 'DATE'
1505
1506    def _sqliteType(self):
1507        return 'DATE'
1508
1509
1510class DateCol(Col):
1511    baseClass = SODateCol
1512
1513
1514class TimeValidator(DateTimeValidator):
1515    def to_python(self, value, state):
1516        if isinstance(value, (datetime.time, sqlbuilder.SQLExpression)):
1517            return value
1518        if isinstance(value, datetime.timedelta):
1519            if value.days:
1520                raise validators.Invalid(
1521                    "the value for the TimeCol '%s' must has days=0, "
1522                    "it has days=%d" % (self.name, value.days), value, state)
1523            return datetime.time(*time.gmtime(value.seconds)[3:6])
1524        value = super(TimeValidator, self).to_python(value, state)
1525        if isinstance(value, datetime.datetime):
1526            value = value.time()
1527        return value
1528
1529    from_python = to_python
1530
1531
1532class SOTimeCol(SOCol):
1533    timeFormat = '%H:%M:%S.%f'
1534
1535    def __init__(self, **kw):
1536        timeFormat = kw.pop('timeFormat', None)
1537        if timeFormat:
1538            self.timeFormat = timeFormat
1539        super(SOTimeCol, self).__init__(**kw)
1540
1541    def createValidators(self):
1542        _validators = super(SOTimeCol, self).createValidators()
1543        if default_datetime_implementation == DATETIME_IMPLEMENTATION:
1544            validatorClass = TimeValidator
1545        elif default_datetime_implementation == MXDATETIME_IMPLEMENTATION:
1546            validatorClass = MXDateTimeValidator
1547        if default_datetime_implementation:
1548            _validators.insert(0, validatorClass(name=self.name,
1549                                                 format=self.timeFormat))
1550        return _validators
1551
1552    def _mysqlType(self):
1553        if self.connection and self.connection.can_use_microseconds():
1554            return 'TIME(6)'
1555        else:
1556            return 'TIME'
1557
1558    def _postgresType(self):
1559        return 'TIME'
1560
1561    def _sybaseType(self):
1562        return 'TIME'
1563
1564    def _mssqlType(self):
1565        if self.connection and self.connection.can_use_microseconds():
1566            return 'TIME(6)'
1567        else:
1568            return 'TIME'
1569
1570    def _sqliteType(self):
1571        return 'TIME'
1572
1573    def _firebirdType(self):
1574        return 'TIME'
1575
1576    def _maxdbType(self):
1577        return 'TIME'
1578
1579
1580class TimeCol(Col):
1581    baseClass = SOTimeCol
1582
1583
1584class SOTimestampCol(SODateTimeCol):
1585    """
1586    Necessary to support MySQL's use of TIMESTAMP versus DATETIME types
1587    """
1588
1589    def __init__(self, **kw):
1590        if 'default' not in kw:
1591            kw['default'] = None
1592        SOCol.__init__(self, **kw)
1593
1594    def _mysqlType(self):
1595        if self.connection and self.connection.can_use_microseconds():
1596            return 'TIMESTAMP(6)'
1597        else:
1598            return 'TIMESTAMP'
1599
1600
1601class TimestampCol(Col):
1602    baseClass = SOTimestampCol
1603
1604
1605class TimedeltaValidator(SOValidator):
1606    def to_python(self, value, state):
1607        return value
1608
1609    from_python = to_python
1610
1611
1612class SOTimedeltaCol(SOCol):
1613    def _postgresType(self):
1614        return 'INTERVAL'
1615
1616    def createValidators(self):
1617        return [TimedeltaValidator(name=self.name)] +               super(SOTimedeltaCol, self).createValidators()
1619
1620
1621class TimedeltaCol(Col):
1622    baseClass = SOTimedeltaCol
1623
1624
1625class DecimalValidator(SOValidator):
1626    def to_python(self, value, state):
1627        if value is None:
1628            return None
1629        if isinstance(value, (int, long, Decimal, sqlbuilder.SQLExpression)):
1630            return value
1631        if isinstance(value, float):
1632            value = str(value)
1633        try:
1634            connection = state.connection or state.soObject._connection
1635        except AttributeError:
1636            pass
1637        else:
1638            if hasattr(connection, "decimalSeparator"):
1639                value = value.replace(connection.decimalSeparator, ".")
1640        try:
1641            return Decimal(value)
1642        except:
1643            raise validators.Invalid(
1644                "expected a Decimal in the DecimalCol '%s', "
1645                "got %s %r instead" % (
1646                    self.name, type(value), value), value, state)
1647
1648    def from_python(self, value, state):
1649        if value is None:
1650            return None
1651        if isinstance(value, float):
1652            value = str(value)
1653        if isinstance(value, string_type):
1654            try:
1655                connection = state.connection or state.soObject._connection
1656            except AttributeError:
1657                pass
1658            else:
1659                if hasattr(connection, "decimalSeparator"):
1660                    value = value.replace(connection.decimalSeparator, ".")
1661            try:
1662                return Decimal(value)
1663            except:
1664                raise validators.Invalid(
1665                    "can not parse Decimal value '%s' "
1666                    "in the DecimalCol from '%s'" % (
1667                        value, getattr(state, 'soObject', '(unknown)')),
1668                    value, state)
1669        if isinstance(value, (int, long, Decimal, sqlbuilder.SQLExpression)):
1670            return value
1671        raise validators.Invalid(
1672            "expected a Decimal in the DecimalCol '%s', got %s %r instead" % (
1673                self.name, type(value), value), value, state)
1674
1675
1676class SODecimalCol(SOCol):
1677
1678    def __init__(self, **kw):
1679        self.size = kw.pop('size', NoDefault)
1680        assert self.size is not NoDefault,               "You must give a size argument"
1682        self.precision = kw.pop('precision', NoDefault)
1683        assert self.precision is not NoDefault,               "You must give a precision argument"
1685        super(SODecimalCol, self).__init__(**kw)
1686
1687    def _sqlType(self):
1688        return 'DECIMAL(%i, %i)' % (self.size, self.precision)
1689
1690    def createValidators(self):
1691        return [DecimalValidator(name=self.name)] +               super(SODecimalCol, self).createValidators()
1693
1694
1695class DecimalCol(Col):
1696    baseClass = SODecimalCol
1697
1698
1699class SOCurrencyCol(SODecimalCol):
1700
1701    def __init__(self, **kw):
1702        pushKey(kw, 'size', 10)
1703        pushKey(kw, 'precision', 2)
1704        super(SOCurrencyCol, self).__init__(**kw)
1705
1706
1707class CurrencyCol(DecimalCol):
1708    baseClass = SOCurrencyCol
1709
1710
1711class DecimalStringValidator(DecimalValidator):
1712    def to_python(self, value, state):
1713        value = super(DecimalStringValidator, self).to_python(value, state)
1714        if self.precision and isinstance(value, Decimal):
1715            assert value < self.max,                   "Value must be less than %s" % int(self.max)
1717            value = value.quantize(self.precision)
1718        return value
1719
1720    def from_python(self, value, state):
1721        value = super(DecimalStringValidator, self).from_python(value, state)
1722        if isinstance(value, Decimal):
1723            if self.precision:
1724                assert value < self.max,                       "Value must be less than %s" % int(self.max)
1726                value = value.quantize(self.precision)
1727            value = value.to_eng_string()
1728        elif isinstance(value, (int, long)):
1729            value = str(value)
1730        return value
1731
1732
1733class SODecimalStringCol(SOStringCol):
1734    def __init__(self, **kw):
1735        self.size = kw.pop('size', NoDefault)
1736        assert (self.size is not NoDefault) and (self.size >= 0),               "You must give a size argument as a positive integer"
1738        self.precision = kw.pop('precision', NoDefault)
1739        assert (self.precision is not NoDefault) and (self.precision >= 0),               "You must give a precision argument as a positive integer"
1741        kw['length'] = int(self.size) + int(self.precision)
1742        self.quantize = kw.pop('quantize', False)
1743        assert isinstance(self.quantize, bool),               "quantize argument must be Boolean True/False"
1745        super(SODecimalStringCol, self).__init__(**kw)
1746
1747    def createValidators(self):
1748        if self.quantize:
1749            v = DecimalStringValidator(
1750                name=self.name,
1751                precision=Decimal(10) ** (-1 * int(self.precision)),
1752                max=Decimal(10) ** (int(self.size) - int(self.precision)))
1753        else:
1754            v = DecimalStringValidator(name=self.name, precision=0)
1755        return [v] +               super(SODecimalStringCol, self).createValidators(dataType=Decimal)
1757
1758
1759class DecimalStringCol(StringCol):
1760    baseClass = SODecimalStringCol
1761
1762
1763class BinaryValidator(SOValidator):
1764    """
1765    Validator for binary types.
1766
1767    We're assuming that the per-database modules provide some form
1768    of wrapper type for binary conversion.
1769    """
1770
1771    _cachedValue = None
1772
1773    def to_python(self, value, state):
1774        if value is None:
1775            return None
1776        try:
1777            connection = state.connection or state.soObject._connection
1778        except AttributeError:
1779            dbName = None
1780            binaryType = type(None)  # Just a simple workaround
1781        else:
1782            dbName = connection.dbName
1783            binaryType = connection._binaryType
1784        if isinstance(value, str):
1785            if dbName == "sqlite":
1786                if not PY2:
1787                    value = bytes(value, 'ascii')
1788                value = connection.module.decode(value)
1789            if dbName == "mysql" and not PY2:
1790                value = value.encode('ascii', errors='surrogateescape')
1791            return value
1792        if isinstance(value, (buffer_type, binaryType)):
1793            cachedValue = self._cachedValue
1794            if cachedValue and cachedValue[1] == value:
1795                return cachedValue[0]
1796            if isinstance(value, array):  # MySQL
1797                return value.tostring()
1798            if not PY2 and isinstance(value, memoryview):
1799                return value.tobytes()
1800            return str(value)  # buffer => string
1801        raise validators.Invalid(
1802            "expected a string in the BLOBCol '%s', got %s %r instead" % (
1803                self.name, type(value), value), value, state)
1804
1805    def from_python(self, value, state):
1806        if value is None:
1807            return None
1808        connection = state.connection or state.soObject._connection
1809        binary = connection.createBinary(value)
1810        if not PY2 and isinstance(binary, memoryview):
1811            binary = str(binary.tobytes(), 'ascii')
1812        self._cachedValue = (value, binary)
1813        return binary
1814
1815
1816class SOBLOBCol(SOStringCol):
1817    def __init__(self, **kw):
1818        # Change the default from 'auto' to False -
1819        # this is a (mostly) binary column
1820        if 'varchar' not in kw:
1821            kw['varchar'] = False
1822        super(SOBLOBCol, self).__init__(**kw)
1823
1824    def createValidators(self):
1825        return [BinaryValidator(name=self.name)] +               super(SOBLOBCol, self).createValidators()
1827
1828    def _mysqlType(self):
1829        length = self.length
1830        varchar = self.varchar
1831        if length:
1832            if length >= 2 ** 24:
1833                return varchar and "LONGTEXT" or "LONGBLOB"
1834            if length >= 2 ** 16:
1835                return varchar and "MEDIUMTEXT" or "MEDIUMBLOB"
1836            if length >= 2 ** 8:
1837                return varchar and "TEXT" or "BLOB"
1838        return varchar and "TINYTEXT" or "TINYBLOB"
1839
1840    def _postgresType(self):
1841        return 'BYTEA'
1842
1843    def _mssqlType(self):
1844        if self.connection and self.connection.can_use_max_types():
1845            return 'VARBINARY(MAX)'
1846        else:
1847            return "IMAGE"
1848
1849
1850class BLOBCol(StringCol):
1851    baseClass = SOBLOBCol
1852
1853
1854class PickleValidator(BinaryValidator):
1855    """
1856    Validator for pickle types.  A pickle type is simply a binary type
1857    with hidden pickling, so that we can simply store any kind of
1858    stuff in a particular column.
1859
1860    The support for this relies directly on the support for binary for
1861    your database.
1862    """
1863
1864    def to_python(self, value, state):
1865        if value is None:
1866            return None
1867        if isinstance(value, unicode_type):
1868            dbEncoding = self.getDbEncoding(state, default='ascii')
1869            value = value.encode(dbEncoding)
1870        if isinstance(value, bytes):
1871            return pickle.loads(value)
1872        raise validators.Invalid(
1873            "expected a pickle string in the PickleCol '%s', "
1874            "got %s %r instead" % (
1875                self.name, type(value), value), value, state)
1876
1877    def from_python(self, value, state):
1878        if value is None:
1879            return None
1880        return pickle.dumps(value, self.pickleProtocol)
1881
1882
1883class SOPickleCol(SOBLOBCol):
1884
1885    def __init__(self, **kw):
1886        self.pickleProtocol = kw.pop('pickleProtocol', pickle.HIGHEST_PROTOCOL)
1887        super(SOPickleCol, self).__init__(**kw)
1888
1889    def createValidators(self):
1890        return [PickleValidator(name=self.name,
1891                pickleProtocol=self.pickleProtocol)] +               super(SOPickleCol, self).createValidators()
1893
1894    def _mysqlType(self):
1895        length = self.length
1896        if length:
1897            if length >= 2 ** 24:
1898                return "LONGBLOB"
1899            if length >= 2 ** 16:
1900                return "MEDIUMBLOB"
1901        return "BLOB"
1902
1903
1904class PickleCol(BLOBCol):
1905    baseClass = SOPickleCol
1906
1907
1908def pushKey(kw, name, value):
1909    if name not in kw:
1910        kw[name] = value
1911
1912all = []
1913# Use copy() to avoid 'dictionary changed' issues on python 3
1914for key, value in globals().copy().items():
1915    if isinstance(value, type) and (issubclass(value, (Col, SOCol))):
1916        all.append(key)
1917__all__.extend(all)
1918del all