"""Join attributes for classes that carry a ``sqlmeta`` (SQLObject style).

``MultipleJoin``/``SQLMultipleJoin`` (one-to-many), ``RelatedJoin``/
``SQLRelatedJoin`` (many-to-many through an intermediate table),
``SingleJoin`` and ``ManyToMany`` are declared on a class; each is later
bound to a concrete ``SO*`` object that actually performs the join.
"""
import sqlbuilder
# Sentinel meaning "no orderBy was given": SOJoin then falls back to the
# other class's ``sqlmeta.defaultOrder``.
NoDefault = sqlbuilder.NoDefault
import styles
import classregistry
import events

# Public names exported by ``import *``.  NOTE(review): ``OneToMany`` is not
# defined in the part of this module reviewed here -- confirm it exists
# further down, otherwise ``import *`` will fail with AttributeError.
__all__ = ['MultipleJoin', 'SQLMultipleJoin', 'RelatedJoin', 'SQLRelatedJoin',
           'SingleJoin', 'ManyToMany', 'OneToMany']
0009
def getID(obj):
    """Return the database id for *obj*.

    *obj* is either something exposing an ``id`` attribute (a row object),
    or a value ``int()`` can convert (e.g. an id given directly).  Only a
    missing ``id`` attribute (AttributeError) triggers the ``int`` fallback.
    """
    missing = object()
    found = getattr(obj, 'id', missing)
    if found is not missing:
        return found
    return int(obj)
0015
class Join(object):
    """Declarative description of a join, bound to a concrete class later.

    Keyword arguments are stored in ``self.kw`` and forwarded to
    ``baseClass`` (set by subclasses) when :meth:`withClass` is called.
    """

    def __init__(self, otherClass=None, **kw):
        kw['otherClass'] = otherClass
        self.kw = kw
        # joinMethodName is kept out of kw so it can also be supplied after
        # construction through the ``joinMethodName`` property below.
        self._joinMethodName = self.kw.pop('joinMethodName', None)

    def _set_joinMethodName(self, value):
        # Refuse to silently replace an explicitly given, different name.
        assert (self._joinMethodName == value
                or self._joinMethodName is None), (
            "You have already given an explicit joinMethodName (%s), "
            "and you are now setting it to %s"
            % (self._joinMethodName, value))
        self._joinMethodName = value

    def _get_joinMethodName(self):
        return self._joinMethodName

    joinMethodName = property(_get_joinMethodName, _set_joinMethodName)
    # ``name`` is an alias for ``joinMethodName``.
    name = joinMethodName

    def withClass(self, soClass):
        """Instantiate ``baseClass`` for *soClass* with the stored keywords.

        A ``joinMethodName`` put back into ``self.kw`` after construction
        takes precedence over the stored one; it is removed from ``self.kw``
        so it is not passed to ``baseClass`` twice.
        """
        # ``dict.has_key`` was removed in Python 3; ``in``/``pop`` work on
        # every Python version.
        if 'joinMethodName' in self.kw:
            self._joinMethodName = self.kw.pop('joinMethodName')
        return self.baseClass(soClass=soClass,
                              joinDef=self,
                              joinMethodName=self._joinMethodName,
                              **self.kw)
0041
class SOJoin(object):
    """Base class for a join that has been bound to a concrete class.

    A join is separate from a foreign key: it is either many-to-many, or
    one-to-many where the *other* class holds the foreign key.
    """

    def __init__(self,
                 soClass=None,
                 otherClass=None,
                 joinColumn=None,
                 joinMethodName=None,
                 orderBy=NoDefault,
                 joinDef=None):
        self.soClass = soClass
        self.joinDef = joinDef
        # otherClass is given by name; the class object itself is stored
        # by _setOtherClass when the registry invokes the callback.
        self.otherClassName = otherClass
        registry = classregistry.registry(soClass.sqlmeta.registry)
        registry.addClassCallback(otherClass, self._setOtherClass)
        self.joinColumn = joinColumn
        self.joinMethodName = joinMethodName
        self._orderBy = orderBy
        if not self.joinColumn:
            # Default: the basic one-to-many join, where the other class
            # points back at this class's table.
            style = styles.getStyle(self.soClass)
            self.joinColumn = style.tableReference(self.soClass.sqlmeta.table)

    def _get_orderBy(self):
        # Resolved lazily: the other class may not exist yet when the join
        # is constructed, so its defaultOrder is only read on first use.
        if self._orderBy is NoDefault:
            self._orderBy = self.otherClass.sqlmeta.defaultOrder
        return self._orderBy
    orderBy = property(_get_orderBy)

    def _setOtherClass(self, cls):
        self.otherClass = cls

    def hasIntermediateTable(self):
        """Plain joins do not use an intermediate table."""
        return False

    def _applyOrderBy(self, results, defaultSortClass):
        """Sort the list *results* in place by ``orderBy`` and return it.

        Nothing is sorted when ``orderBy`` is None.  *defaultSortClass* is
        accepted for interface compatibility but is not used.
        """
        ordering = self.orderBy
        if ordering is None:
            return results
        results.sort(sorter(ordering))
        return results
0085
def sorter(orderBy):
    """Build a ``cmp``-style comparison function for in-memory ordering.

    *orderBy* may be an attribute name (``-`` prefix for descending), a
    ``sqlbuilder.SQLObjectField``, a ``sqlbuilder.DESC`` wrapping such a
    field, or a list/tuple of these for multi-key ordering (earlier keys
    win; later keys break ties).  An empty list/tuple imposes no order.

    The returned function takes two objects and returns a negative, zero
    or positive int.  ``None`` attribute values sort before everything
    else (and therefore last when descending).
    """
    if isinstance(orderBy, (tuple, list)):
        if not orderBy:
            # No keys means "no ordering"; orderBy[0] used to raise
            # IndexError here.
            return lambda a, b: 0
        if len(orderBy) == 1:
            orderBy = orderBy[0]
        else:
            fhead = sorter(orderBy[0])
            frest = sorter(orderBy[1:])
            return lambda a, b, fhead=fhead, frest=frest: fhead(a, b) or frest(a, b)
    if (isinstance(orderBy, sqlbuilder.DESC)
            and isinstance(orderBy.expr, sqlbuilder.SQLObjectField)):
        orderBy = '-' + orderBy.expr.original
    elif isinstance(orderBy, sqlbuilder.SQLObjectField):
        orderBy = orderBy.original
    # @@: more complex SQL expressions are not handled for in-memory ordering.
    if orderBy.startswith('-'):
        orderBy = orderBy[1:]
        reverse = True
    else:
        reverse = False

    def cmper(a, b, attr=orderBy, rev=reverse):
        a = getattr(a, attr)
        b = getattr(b, attr)
        if rev:
            a, b = b, a
        if a is None:
            if b is None:
                return 0
            return -1
        if b is None:
            return 1
        # Same result as the Python 2 builtin cmp(a, b), which no longer
        # exists on Python 3.
        return (a > b) - (a < b)
    return cmper
0119
class SOMultipleJoin(SOJoin):
    """One-to-many join: the ``otherClass`` rows whose ``joinColumn``
    points back at an instance of ``soClass``."""

    def __init__(self, addRemoveName=None, **kw):
        SOJoin.__init__(self, **kw)

        # Without an explicit name, pluralise the other class name with a
        # lower-case first letter: "Phone" -> "phones", "Address" -> "addresses".
        if not self.joinMethodName:
            stem = self.otherClassName[0].lower() + self.otherClassName[1:]
            suffix = "s"
            if stem.endswith('s'):
                suffix = "es"
            self.joinMethodName = stem + suffix
        if addRemoveName:
            self.addRemoveName = addRemoveName
        else:
            self.addRemoveName = capitalize(self.otherClassName)

    def performJoin(self, inst):
        """Return the related ``otherClass`` objects for *inst* as a list,
        sorted in Python via ``_applyOrderBy``."""
        rows = inst._connection._SO_selectJoin(
            self.otherClass, self.joinColumn, inst.id)
        conn = None
        if inst.sqlmeta._perConnection:
            conn = inst._connection
        related = [self.otherClass.get(otherID, conn)
                   for (otherID,) in rows if otherID is not None]
        return self._applyOrderBy(related, self.otherClass)

    def _dbNameToPythonName(self):
        """Map ``joinColumn`` (a database column name) to the matching
        attribute name on ``otherClass``; fall back to the style's naming
        conversion when no declared column matches."""
        matches = [col.name
                   for col in self.otherClass.sqlmeta.columns.values()
                   if col.dbName == self.joinColumn]
        if matches:
            return matches[0]
        return self.soClass.sqlmeta.style.dbColumnToPythonAttr(self.joinColumn)
0156
class MultipleJoin(Join):
    """Declarative one-to-many join, bound to :class:`SOMultipleJoin`."""

    baseClass = SOMultipleJoin
0159
class SOSQLMultipleJoin(SOMultipleJoin):
    """One-to-many join that returns the ``otherClass.select(...)`` result
    (ordered through its ``orderBy``) instead of a Python list."""

    def performJoin(self, inst):
        if not inst.sqlmeta._perConnection:
            conn = None
        else:
            conn = inst._connection
        column = getattr(self.otherClass.q, self._dbNameToPythonName())
        results = self.otherClass.select(column == inst.id, connection=conn)
        return results.orderBy(self.orderBy)
0170
class SQLMultipleJoin(Join):
    """Declarative one-to-many join, bound to :class:`SOSQLMultipleJoin`."""

    baseClass = SOSQLMultipleJoin
0173
class SORelatedJoin(SOMultipleJoin):
    """Many-to-many join through an intermediate table.

    The intermediate table holds ``joinColumn`` (referring to this class)
    and ``otherColumn`` (referring to ``otherClass``).  When not given, its
    name is the two table names, sorted, joined with ``_``.
    """

    def __init__(self,
                 otherColumn=None,
                 intermediateTable=None,
                 createRelatedTable=True,
                 **kw):
        # These must exist before the base initialiser runs.
        self.intermediateTable = intermediateTable
        self.otherColumn = otherColumn
        self.createRelatedTable = createRelatedTable
        SOMultipleJoin.__init__(self, **kw)
        # Defaults that depend on the other class are filled in by callback.
        registry = classregistry.registry(self.soClass.sqlmeta.registry)
        registry.addClassCallback(self.otherClassName,
                                  self._setOtherRelatedClass)

    def _setOtherRelatedClass(self, otherClass):
        if not self.intermediateTable:
            tables = [self.soClass.sqlmeta.table, otherClass.sqlmeta.table]
            tables.sort()
            self.intermediateTable = '%s_%s' % tuple(tables)
        if not self.otherColumn:
            style = self.soClass.sqlmeta.style
            self.otherColumn = style.tableReference(otherClass.sqlmeta.table)

    def hasIntermediateTable(self):
        return True

    def performJoin(self, inst):
        """Return the related ``otherClass`` objects for *inst* as a list,
        sorted in Python via ``_applyOrderBy``."""
        rows = inst._connection._SO_intermediateJoin(
            self.intermediateTable, self.otherColumn, self.joinColumn, inst.id)
        conn = None
        if inst.sqlmeta._perConnection:
            conn = inst._connection
        related = [self.otherClass.get(otherID, conn)
                   for (otherID,) in rows if otherID is not None]
        return self._applyOrderBy(related, self.otherClass)

    def remove(self, inst, other):
        """Delete the intermediate row linking *inst* and *other* (*other*
        may be an object or anything ``getID`` accepts)."""
        inst._connection._SO_intermediateDelete(
            self.intermediateTable,
            self.joinColumn, getID(inst),
            self.otherColumn, getID(other))

    def add(self, inst, other):
        """Insert an intermediate row linking *inst* and *other* (*other*
        may be an object or anything ``getID`` accepts)."""
        inst._connection._SO_intermediateInsert(
            self.intermediateTable,
            self.joinColumn, getID(inst),
            self.otherColumn, getID(other))
0231
class RelatedJoin(MultipleJoin):
    """Declarative many-to-many join, bound to :class:`SORelatedJoin`."""

    baseClass = SORelatedJoin
0234
# Helper SQL expressions used by SOSQLRelatedJoin.
class OtherTableToJoin(sqlbuilder.SQLExpression):
    """SQL condition ``otherTable.otherIdName = interTable.joinColumn``."""

    def __init__(self, otherTable, otherIdName, interTable, joinColumn):
        self.otherTable, self.otherIdName = otherTable, otherIdName
        self.interTable, self.joinColumn = interTable, joinColumn

    def tablesUsedImmediate(self):
        return [self.otherTable, self.interTable]

    def __sqlrepr__(self, db):
        left = '%s.%s' % (self.otherTable, self.otherIdName)
        right = '%s.%s' % (self.interTable, self.joinColumn)
        return '%s = %s' % (left, right)
0248
class JoinToTable(sqlbuilder.SQLExpression):
    """SQL condition ``interTable.joinColumn = table.idName``."""

    def __init__(self, table, idName, interTable, joinColumn):
        self.table, self.idName = table, idName
        self.interTable, self.joinColumn = interTable, joinColumn

    def tablesUsedImmediate(self):
        return [self.table, self.interTable]

    def __sqlrepr__(self, db):
        left = '%s.%s' % (self.interTable, self.joinColumn)
        right = '%s.%s' % (self.table, self.idName)
        return '%s = %s' % (left, right)
0261
class TableToId(sqlbuilder.SQLExpression):
    """SQL condition ``table.idName = <idValue>`` with a properly rendered id."""

    def __init__(self, table, idName, idValue):
        self.table = table
        self.idName = idName
        self.idValue = idValue

    def tablesUsedImmediate(self):
        return [self.table]

    def __sqlrepr__(self, db):
        # Render the id as a SQL literal for this backend.  Plain '%s'
        # interpolation left string ids unquoted and unescaped, producing
        # invalid (or injectable) SQL; integer ids render the same as before.
        return '%s.%s = %s' % (self.table, self.idName,
                               sqlbuilder.sqlrepr(self.idValue, db))
0273
class SOSQLRelatedJoin(SORelatedJoin):
    """Many-to-many join that returns an ``otherClass.select(...)`` result
    built from explicit conditions over the intermediate table."""

    def performJoin(self, inst):
        if not inst.sqlmeta._perConnection:
            conn = None
        else:
            conn = inst._connection
        this = self.soClass.sqlmeta
        other = self.otherClass.sqlmeta
        interTable = self.intermediateTable
        condition = sqlbuilder.AND(
            OtherTableToJoin(other.table, other.idName,
                             interTable, self.otherColumn),
            JoinToTable(this.table, this.idName,
                        interTable, self.joinColumn),
            TableToId(this.table, this.idName, inst.id),
        )
        results = self.otherClass.select(
            condition,
            clauseTables=(this.table, other.table, interTable),
            connection=conn)
        return results.orderBy(self.orderBy)
0293
class SQLRelatedJoin(RelatedJoin):
    """Declarative many-to-many join, bound to :class:`SOSQLRelatedJoin`."""

    baseClass = SOSQLRelatedJoin
0296
def capitalize(name):
    """Return *name* with only its first character capitalized.

    Unlike ``str.capitalize`` the rest of the string is left untouched
    (``'fooBar'`` -> ``'FooBar'``).  An empty string is returned as-is;
    slicing avoids the IndexError that ``name[0]`` raised.
    """
    return name[:1].capitalize() + name[1:]
0299
class SOSingleJoin(SOMultipleJoin):
    """Like :class:`SOMultipleJoin`, but yields at most one related object.

    With ``makeDefault=True`` a missing related object is created on
    demand; otherwise ``None`` is returned when there is no match.
    """

    def __init__(self, **kw):
        self.makeDefault = kw.pop('makeDefault', False)
        SOMultipleJoin.__init__(self, **kw)

    def performJoin(self, inst):
        if inst.sqlmeta._perConnection:
            conn = inst._connection
        else:
            conn = None
        pythonColumn = self._dbNameToPythonName()
        results = self.otherClass.select(
            getattr(self.otherClass.q, pythonColumn) == inst.id,
            connection=conn
        )
        # Fetch at most one row in a single query.  The old count() followed
        # by results[0] issued two queries and raised IndexError if the row
        # disappeared between them.
        found = list(results[:1])
        if found:
            return found[0]
        if not self.makeDefault:
            return None
        # Create the missing object, passing inst for the attribute that
        # corresponds to the join column.
        kw = {self.soClass.sqlmeta.style.instanceIDAttrToAttr(pythonColumn): inst}
        return self.otherClass(**kw)
0324
class SingleJoin(Join):
    """Declarative single-object join, bound to :class:`SOSingleJoin`."""

    baseClass = SOSingleJoin
0327
0328
0329
0330import boundattributes
0331
class SOManyToMany(object):
    """Descriptor for a many-to-many relation through an intermediate table.

    Reading the attribute from an instance returns a
    ``_ManyToManySelectWrapper`` around a ``select`` of related
    ``otherClass`` rows.  Table/column names left as None are derived in
    ``_finishSet`` once both classes have been registered.
    """

    def __init__(self, soClass, name, join,
                 intermediateTable, joinColumn, otherColumn,
                 createJoinTable, **attrs):
        self.name = name
        self.intermediateTable = intermediateTable
        self.joinColumn = joinColumn
        self.otherColumn = otherColumn
        self.createJoinTable = createJoinTable
        self.soClass = self.otherClass = None
        # Extra keyword attributes are stored on the instance verbatim.
        # (Loop variable renamed so it no longer shadows the ``name`` arg.)
        for attrName, value in attrs.items():
            setattr(self, attrName, value)
        classregistry.registry(
            soClass.sqlmeta.registry).addClassCallback(
            join, self._setOtherClass)
        classregistry.registry(
            soClass.sqlmeta.registry).addClassCallback(
            soClass.__name__, self._setThisClass)

    def _setThisClass(self, soClass):
        self.soClass = soClass
        if self.soClass and self.otherClass:
            self._finishSet()

    def _setOtherClass(self, otherClass):
        self.otherClass = otherClass
        if self.soClass and self.otherClass:
            self._finishSet()

    def _finishSet(self):
        """Fill in default names, hook table creation for both classes and
        precompute the join clause; called once both classes are known."""
        if self.intermediateTable is None:
            names = [self.soClass.sqlmeta.table,
                     self.otherClass.sqlmeta.table]
            names.sort()
            self.intermediateTable = '%s_%s' % (names[0], names[1])
        if not self.otherColumn:
            self.otherColumn = self.soClass.sqlmeta.style.tableReference(
                self.otherClass.sqlmeta.table)
        if not self.joinColumn:
            self.joinColumn = styles.getStyle(
                self.soClass).tableReference(self.soClass.sqlmeta.table)
        events.listen(self.event_CreateTableSignal,
                      self.soClass, events.CreateTableSignal)
        events.listen(self.event_CreateTableSignal,
                      self.otherClass, events.CreateTableSignal)
        self.clause = (
            (self.otherClass.q.id ==
             sqlbuilder.Field(self.intermediateTable, self.otherColumn))
            & (sqlbuilder.Field(self.intermediateTable, self.joinColumn)
               == self.soClass.q.id))

    def __get__(self, obj, type):
        if obj is None:
            return self
        query = (
            (self.otherClass.q.id ==
             sqlbuilder.Field(self.intermediateTable, self.otherColumn))
            & (sqlbuilder.Field(self.intermediateTable, self.joinColumn)
               == obj.id))
        # Honour per-connection instances like the other joins in this
        # module do; previously the select always used the default
        # connection.
        if obj.sqlmeta._perConnection:
            select = self.otherClass.select(query, connection=obj._connection)
        else:
            select = self.otherClass.select(query)
        return _ManyToManySelectWrapper(obj, self, select)

    def event_CreateTableSignal(self, soClass, connection, extra_sql,
                                post_funcs):
        # When enabled, defer join-table creation to a post_funcs hook
        # (presumably run after the class's own table is created).
        if self.createJoinTable:
            post_funcs.append(self.event_CreateTableSignalPost)

    def event_CreateTableSignalPost(self, soClass, connection):
        # Both classes trigger this hook; create the table only once.
        if connection.tableExists(self.intermediateTable):
            return
        connection._SO_createJoinTable(self)
0404
class ManyToMany(boundattributes.BoundFactory):
    """Declarative factory producing :class:`SOManyToMany` descriptors.

    ``__unpackargs__`` presumably lets ``join`` be given positionally --
    see ``boundattributes.BoundFactory``.
    """
    factory_class = SOManyToMany
    __unpackargs__ = ('join',)
    __restrict_attributes__ = (
        'join',
        'intermediateTable',
        'joinColumn',
        'otherColumn',
        'createJoinTable',
    )

    # Defaults; the None-valued names are derived by SOManyToMany later.
    createJoinTable = True
    intermediateTable = None
    joinColumn = None
    otherColumn = None
0417
0418class _ManyToManySelectWrapper(object):
0419
0420    def __init__(self, forObject, join, select):
0421        self.forObject = forObject
0422        self.join = join
0423        self.select = select
0424
0425    def __getattr__(self, attr):
0426        # @@: This passes through private variable access too... should it?
0427        # Also magic methods, like __str__
0428        return getattr(self.select, attr)
0429
0430    def __repr__(self):
0431        return '<%s for: %s>' % (self.__class__.__name__, repr(self.select))
0432
0433    def __str__(self):
0434        return str(self.select)
0435
0436    def __iter__(self):
0437        return iter(self.select)
0438
0439    def __getitem__(self, key):
0440        return self.select[key]
0441
0442    def add(self, obj):
0443        obj._connection._SO_intermediateInsert(
0444            self.join.intermediateTable,
0445            self.join.joinColumn,
0446            getID(self.forObject),
0447            self.join.otherColumn,
0448            getID(obj))
0449
0450    def remove(self, obj):