0001from itertools import count
0002from . import boundattributes
0003from . import classregistry
0004from . import events
0005from . import styles
0006from . import sqlbuilder
0007from .styles import capword
0008
# Public join declarations exported by this module.
__all__ = [
    'MultipleJoin', 'SQLMultipleJoin',
    'RelatedJoin', 'SQLRelatedJoin',
    'SingleJoin',
    'ManyToMany', 'OneToMany',
]

# Shared counter: each Join takes the next value in __init__ and hands it to
# its SO* object in withClass (presumably to keep declaration order).
creationOrder = count()
# Sentinel from sqlbuilder; used below as the "orderBy not given" default.
NoDefault = sqlbuilder.NoDefault
0014
0015
def getID(obj):
    """Return the id of *obj*.

    Objects exposing an ``id`` attribute yield that value unchanged; anything
    else is converted with ``int()``.
    """
    missing = object()
    value = getattr(obj, 'id', missing)
    if value is missing:
        return int(obj)
    return value
0021
0022
class Join(object):
    """Declarative join description.

    Stores its constructor keywords until ``withClass`` binds the join to a
    concrete class by instantiating ``baseClass`` (defined by subclasses).
    """

    def __init__(self, otherClass=None, **kw):
        kw['otherClass'] = otherClass
        self.kw = kw
        # joinMethodName is kept out of the forwarded keywords because
        # withClass passes it explicitly.
        self._joinMethodName = self.kw.pop('joinMethodName', None)
        self.creationOrder = next(creationOrder)

    def _set_joinMethodName(self, value):
        # An explicit raise (rather than ``assert``) keeps this check active
        # under ``python -O``; the exception type is unchanged for callers.
        if not (self._joinMethodName == value
                or self._joinMethodName is None):
            raise AssertionError(
                "You have already given an explicit joinMethodName (%s), "
                "and you are now setting it to %s"
                % (self._joinMethodName, value))
        self._joinMethodName = value

    def _get_joinMethodName(self):
        return self._joinMethodName

    joinMethodName = property(_get_joinMethodName, _set_joinMethodName)
    name = joinMethodName

    def withClass(self, soClass):
        """Instantiate ``baseClass`` for *soClass* with the stored keywords.

        Returns whatever ``self.baseClass(...)`` returns.
        """
        # Defensive: if joinMethodName found its way back into kw, move it
        # out so it is not passed twice in the call below.
        if 'joinMethodName' in self.kw:
            self._joinMethodName = self.kw.pop('joinMethodName')
        return self.baseClass(creationOrder=self.creationOrder,
                              soClass=soClass,
                              joinDef=self,
                              joinMethodName=self._joinMethodName,
                              **self.kw)
0052
0053
0054# A join is separate from a foreign key, i.e., it is
0055# many-to-many, or one-to-many where the *other* class
0056# has the foreign key.
0057
0058
class SOJoin(object):
    """Base for the per-class join objects built by ``Join.withClass``.

    Holds the join configuration for ``soClass``. ``otherClass`` arrives as a
    class name; the class object is filled in by ``_setOtherClass`` once the
    class registry reports it.
    """

    def __init__(self,
                 creationOrder,
                 soClass=None,
                 otherClass=None,
                 joinColumn=None,
                 joinMethodName=None,
                 orderBy=NoDefault,
                 joinDef=None):
        self.creationOrder = creationOrder
        self.soClass = soClass
        self.joinDef = joinDef
        self.otherClassName = otherClass
        registry = classregistry.registry(soClass.sqlmeta.registry)
        registry.addClassCallback(otherClass, self._setOtherClass)
        self.joinColumn = joinColumn
        self.joinMethodName = joinMethodName
        self._orderBy = orderBy
        if not self.joinColumn:
            # Default: a plain one-to-many join, where the other class has a
            # column referencing our table.
            style = styles.getStyle(self.soClass)
            self.joinColumn = style.tableReference(self.soClass.sqlmeta.table)

    @property
    def orderBy(self):
        """Ordering for join results.

        Lazily defaults to ``otherClass.sqlmeta.defaultOrder`` and caches it.
        """
        if self._orderBy is NoDefault:
            self._orderBy = self.otherClass.sqlmeta.defaultOrder
        return self._orderBy

    def _setOtherClass(self, cls):
        self.otherClass = cls

    def hasIntermediateTable(self):
        return False

    def _applyOrderBy(self, results, defaultSortClass):
        """Sort the list *results* in place by ``orderBy`` (if any); return it."""
        ordering = self.orderBy
        if ordering is not None:
            doSort(results, ordering)
        return results
0101
0102
class MinType(object):
    """Sentinel that compares lower than every other value.

    doSort uses the ``Min`` instance as the sort key for ``None`` attribute
    values. The comparison methods are spelled out by hand rather than
    derived with ``functools.total_ordering``.
    """

    def __lt__(self, other):
        return other is not self

    def __le__(self, other):
        return True

    def __eq__(self, other):
        return other is self

    def __ge__(self, other):
        return other is self

    def __gt__(self, other):
        return False


Min = MinType()
0129
0130
def doSort(results, orderBy):
    """Sort the list *results* in place by one or more attribute orderings.

    ``orderBy`` may be:

    * an attribute name, optionally prefixed with ``-`` for descending order;
    * a ``sqlbuilder.SQLObjectField`` (sorted by its ``original`` name), or a
      ``sqlbuilder.DESC`` wrapping one;
    * a tuple or list of the above, most significant ordering first (as in
      an SQL ``ORDER BY`` clause). An empty sequence leaves *results*
      untouched.

    Attribute values of ``None`` compare lower than any other value (``Min``).
    """
    if isinstance(orderBy, (tuple, list)):
        # list.sort is stable, so sorting by each key from the least to the
        # most significant one leaves the FIRST key as the primary ordering.
        # (Sorting by the first key first would make the last key primary.)
        for key in reversed(orderBy):
            doSort(results, key)
        return
    if (isinstance(orderBy, sqlbuilder.DESC)
            and isinstance(orderBy.expr, sqlbuilder.SQLObjectField)):
        orderBy = '-' + orderBy.expr.original
    elif isinstance(orderBy, sqlbuilder.SQLObjectField):
        orderBy = orderBy.original
    # @@: more complex expressions are not handled; anything without a
    # ``startswith`` method (i.e. not a string) fails below.
    if orderBy.startswith('-'):
        orderBy = orderBy[1:]
        reverse = True
    else:
        reverse = False

    def sortkey(x, attr=orderBy):
        a = getattr(x, attr)
        if a is None:
            return Min
        return a
    results.sort(key=sortkey, reverse=reverse)
0159
0160
# One-to-many join: the *other* class holds the column referencing us.
0162
0163
class SOMultipleJoin(SOJoin):
    """One-to-many join whose accessor returns a list of related objects."""

    def __init__(self, addRemoveName=None, **kw):
        SOJoin.__init__(self, **kw)

        # Without an explicit accessor name, derive one from the other class
        # name: lower-case the first letter and naively pluralize it.
        if not self.joinMethodName:
            other = self.otherClassName
            stem = other[0].lower() + other[1:]
            suffix = "es" if stem.endswith('s') else "s"
            self.joinMethodName = stem + suffix
        self.addRemoveName = addRemoveName or capword(self.otherClassName)

    def performJoin(self, inst):
        """Return the related objects as a list, sorted by ``orderBy`` if set."""
        idRows = inst._connection._SO_selectJoin(
            self.otherClass, self.joinColumn, inst.id)
        conn = inst._connection if inst.sqlmeta._perConnection else None
        related = [self.otherClass.get(otherID, conn)
                   for (otherID,) in idRows
                   if otherID is not None]
        return self._applyOrderBy(related, self.otherClass)

    def _dbNameToPythonName(self):
        """Map ``joinColumn`` (a DB column name) to a Python attribute name.

        Uses the other class's matching column if there is one; otherwise it
        falls back to the style's conversion.
        """
        for col in self.otherClass.sqlmeta.columns.values():
            if col.dbName == self.joinColumn:
                break
        else:
            return self.soClass.sqlmeta.style.dbColumnToPythonAttr(
                self.joinColumn)
        return col.name
0201
0202
class MultipleJoin(Join):
    """Declare a one-to-many join.

    Bound to a class via ``SOMultipleJoin``, whose ``performJoin`` returns a
    list of related objects.
    """
    baseClass = SOMultipleJoin
0205
0206
class SOSQLMultipleJoin(SOMultipleJoin):
    """One-to-many join whose accessor returns an ordered ``select`` result
    instead of a list."""

    def performJoin(self, inst):
        conn = inst._connection if inst.sqlmeta._perConnection else None
        column = getattr(self.otherClass.q, self._dbNameToPythonName())
        selection = self.otherClass.select(column == inst.id,
                                           connection=conn)
        return selection.orderBy(self.orderBy)
0219
0220
class SQLMultipleJoin(Join):
    """Declare a one-to-many join whose accessor returns the ordered result
    of ``otherClass.select(...)`` (via ``SOSQLMultipleJoin``) rather than a
    list."""
    baseClass = SOSQLMultipleJoin
0223
0224
# Many-to-many join through an intermediate table.
0226
0227
class SORelatedJoin(SOMultipleJoin):
    """Many-to-many join stored in an intermediate table.

    The table name and the other-side column default to values derived once
    the other class is registered (see ``_setOtherRelatedClass``).
    """

    def __init__(self,
                 otherColumn=None,
                 intermediateTable=None,
                 createRelatedTable=True,
                 **kw):
        self.intermediateTable = intermediateTable
        self.otherColumn = otherColumn
        self.createRelatedTable = createRelatedTable
        SOMultipleJoin.__init__(self, **kw)
        registry = classregistry.registry(self.soClass.sqlmeta.registry)
        registry.addClassCallback(self.otherClassName,
                                  self._setOtherRelatedClass)

    def _setOtherRelatedClass(self, otherClass):
        if not self.intermediateTable:
            # Default table name: both table names, sorted, joined by "_".
            first, second = sorted([self.soClass.sqlmeta.table,
                                    otherClass.sqlmeta.table])
            self.intermediateTable = '%s_%s' % (first, second)
        if not self.otherColumn:
            style = self.soClass.sqlmeta.style
            self.otherColumn = style.tableReference(otherClass.sqlmeta.table)

    def hasIntermediateTable(self):
        return True

    def performJoin(self, inst):
        """Return the related objects as a list, sorted by ``orderBy`` if set."""
        idRows = inst._connection._SO_intermediateJoin(
            self.intermediateTable, self.otherColumn, self.joinColumn,
            inst.id)
        conn = inst._connection if inst.sqlmeta._perConnection else None
        related = [self.otherClass.get(otherID, conn)
                   for (otherID,) in idRows
                   if otherID is not None]
        return self._applyOrderBy(related, self.otherClass)

    def remove(self, inst, other):
        """Delete the intermediate-table row linking *inst* and *other*."""
        conn = inst._connection
        conn._SO_intermediateDelete(
            self.intermediateTable,
            self.joinColumn, getID(inst),
            self.otherColumn, getID(other))

    def add(self, inst, other):
        """Insert an intermediate-table row linking *inst* and *other*."""
        conn = inst._connection
        conn._SO_intermediateInsert(
            self.intermediateTable,
            self.joinColumn, getID(inst),
            self.otherColumn, getID(other))
0285
0286
class RelatedJoin(MultipleJoin):
    """Declare a many-to-many join through an intermediate table.

    Bound via ``SORelatedJoin``, whose ``performJoin`` returns a list and
    which provides ``add``/``remove``.
    """
    baseClass = SORelatedJoin
0289
0290
# Helper SQL expressions that SOSQLRelatedJoin combines into its WHERE clause.
0292
0293
class OtherTableToJoin(sqlbuilder.SQLExpression):
    """SQL condition ``otherTable.otherIdName = interTable.joinColumn``."""

    def __init__(self, otherTable, otherIdName, interTable, joinColumn):
        self.otherTable, self.otherIdName = otherTable, otherIdName
        self.interTable, self.joinColumn = interTable, joinColumn

    def tablesUsedImmediate(self):
        return list((self.otherTable, self.interTable))

    def __sqlrepr__(self, db):
        lhs = (self.otherTable, self.otherIdName)
        rhs = (self.interTable, self.joinColumn)
        return '%s.%s = %s.%s' % (lhs + rhs)
0307
0308
class JoinToTable(sqlbuilder.SQLExpression):
    """SQL condition ``interTable.joinColumn = table.idName``."""

    def __init__(self, table, idName, interTable, joinColumn):
        self.table, self.idName = table, idName
        self.interTable, self.joinColumn = interTable, joinColumn

    def tablesUsedImmediate(self):
        return list((self.table, self.interTable))

    def __sqlrepr__(self, db):
        lhs = (self.interTable, self.joinColumn)
        rhs = (self.table, self.idName)
        return '%s.%s = %s.%s' % (lhs + rhs)
0322
0323
class TableToId(sqlbuilder.SQLExpression):
    """SQL condition ``table.idName = <idValue>``."""

    def __init__(self, table, idName, idValue):
        self.table = table
        self.idName = idName
        self.idValue = idValue

    def tablesUsedImmediate(self):
        return [self.table]

    def __sqlrepr__(self, db):
        # Render the value via sqlrepr so non-integer ids (e.g. string
        # primary keys) are quoted/escaped for the backend rather than being
        # spliced into the SQL verbatim; integer ids render as before.
        return '%s.%s = %s' % (self.table, self.idName,
                               sqlbuilder.sqlrepr(self.idValue, db))
0335
0336
class SOSQLRelatedJoin(SORelatedJoin):
    """Many-to-many join whose accessor returns an ordered ``select`` result
    instead of a list."""

    def performJoin(self, inst):
        conn = inst._connection if inst.sqlmeta._perConnection else None
        thisMeta = self.soClass.sqlmeta
        otherMeta = self.otherClass.sqlmeta
        where = sqlbuilder.AND(
            OtherTableToJoin(otherMeta.table, otherMeta.idName,
                             self.intermediateTable, self.otherColumn),
            JoinToTable(thisMeta.table, thisMeta.idName,
                        self.intermediateTable, self.joinColumn),
            TableToId(thisMeta.table, thisMeta.idName, inst.id),
        )
        tables = (thisMeta.table, otherMeta.table, self.intermediateTable)
        selection = self.otherClass.select(where, clauseTables=tables,
                                           connection=conn)
        return selection.orderBy(self.orderBy)
0359
0360
class SQLRelatedJoin(RelatedJoin):
    """Declare a many-to-many join whose accessor returns the ordered result
    of ``otherClass.select(...)`` (via ``SOSQLRelatedJoin``)."""
    baseClass = SOSQLRelatedJoin
0363
0364
class SOSingleJoin(SOMultipleJoin):
    """Join that yields a single related object.

    ``performJoin`` returns the first matching row, or ``None`` when there is
    none. With ``makeDefault=True`` it instead creates a related row pointing
    at the instance.
    """

    def __init__(self, **kw):
        self.makeDefault = kw.pop('makeDefault', False)
        SOMultipleJoin.__init__(self, **kw)

    def performJoin(self, inst):
        conn = inst._connection if inst.sqlmeta._perConnection else None
        pythonColumn = self._dbNameToPythonName()
        matches = self.otherClass.select(
            getattr(self.otherClass.q, pythonColumn) == inst.id,
            connection=conn)
        if matches.count():
            return matches[0]
        if not self.makeDefault:
            return None
        # No related row yet: create one that references ``inst``.
        style = self.soClass.sqlmeta.style
        attrName = style.instanceIDAttrToAttr(pythonColumn)
        return self.otherClass(**{attrName: inst})
0391
0392
class SingleJoin(Join):
    """Declare a join that yields one related object (see ``SOSingleJoin``):
    the first match, ``None``, or a newly created row when ``makeDefault``
    is set."""
    baseClass = SOSingleJoin
0395
0396
class SOManyToMany(object):
    """Descriptor for a many-to-many relation through an intermediate table.

    ``soClass`` and ``otherClass`` are filled in by class-registry callbacks.
    Once both are known, ``_finishSet`` fills in default table/column names
    and builds ``self.clause``.
    """

    def __init__(self, soClass, name, join,
                 intermediateTable, joinColumn, otherColumn,
                 createJoinTable, **attrs):
        self.name = name
        self.intermediateTable = intermediateTable
        self.joinColumn = joinColumn
        self.otherColumn = otherColumn
        self.createJoinTable = createJoinTable
        self.soClass = self.otherClass = None
        for attrName, attrValue in attrs.items():
            setattr(self, attrName, attrValue)
        registry = classregistry.registry(soClass.sqlmeta.registry)
        registry.addClassCallback(join, self._setOtherClass)
        registry.addClassCallback(soClass.__name__, self._setThisClass)

    def _setThisClass(self, soClass):
        self.soClass = soClass
        if not (self.soClass and self.otherClass):
            return
        self._finishSet()

    def _setOtherClass(self, otherClass):
        self.otherClass = otherClass
        if not (self.soClass and self.otherClass):
            return
        self._finishSet()

    def _joinClause(self, thisID):
        # other.id = inter.otherColumn AND inter.joinColumn = thisID
        otherSide = (self.otherClass.q.id ==
                     sqlbuilder.Field(self.intermediateTable,
                                      self.otherColumn))
        thisSide = (sqlbuilder.Field(self.intermediateTable,
                                     self.joinColumn) == thisID)
        return otherSide & thisSide

    def _finishSet(self):
        thisTable = self.soClass.sqlmeta.table
        otherTable = self.otherClass.sqlmeta.table
        if self.intermediateTable is None:
            self.intermediateTable = '%s_%s' % tuple(
                sorted([thisTable, otherTable]))
        if not self.otherColumn:
            self.otherColumn = self.soClass.sqlmeta.style.tableReference(
                otherTable)
        if not self.joinColumn:
            self.joinColumn = styles.getStyle(
                self.soClass).tableReference(thisTable)
        for cls in (self.soClass, self.otherClass):
            events.listen(self.event_CreateTableSignal,
                          cls, events.CreateTableSignal)
        self.clause = self._joinClause(self.soClass.q.id)

    def __get__(self, obj, type):
        if obj is None:
            return self
        select = self.otherClass.select(self._joinClause(obj.id))
        return _ManyToManySelectWrapper(obj, self, select)

    def event_CreateTableSignal(self, soClass, connection, extra_sql,
                                post_funcs):
        # Defer creating the intermediate table until after the class's
        # own table has been handled.
        if self.createJoinTable:
            post_funcs.append(self.event_CreateTableSignalPost)

    def event_CreateTableSignalPost(self, soClass, connection):
        if not connection.tableExists(self.intermediateTable):
            connection._SO_createJoinTable(self)
0469
0470
class ManyToMany(boundattributes.BoundFactory):
    """Declarative many-to-many relation; ``factory_class`` is SOManyToMany.

    The meaning of ``__restrict_attributes__`` and ``__unpackargs__`` is
    defined by ``boundattributes.BoundFactory`` (not visible here).
    """
    factory_class = SOManyToMany
    # NOTE(review): presumably the only attributes this factory accepts —
    # confirm against boundattributes.
    __restrict_attributes__ = (
        'join', 'intermediateTable',
        'joinColumn', 'otherColumn', 'createJoinTable')
    # NOTE(review): presumably lets ``join`` be given positionally — confirm.
    __unpackargs__ = ('join',)

    # Default values:
    # (If passed through to SOManyToMany, None here makes _finishSet derive
    # the table/column names; createJoinTable gates event_CreateTableSignal.)
    intermediateTable = None
    joinColumn = None
    otherColumn = None
    createJoinTable = True
0483
0484
class _ManyToManySelectWrapper(object):
    """Proxy around a ``select`` result, returned by ``SOManyToMany.__get__``.

    Attribute access falls through to the wrapped ``select``. ``add``,
    ``remove`` and ``create`` maintain rows in the join's intermediate table
    for ``forObject``.
    """

    def __init__(self, forObject, join, select):
        self.forObject = forObject
        self.join = join
        self.select = select

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Guard ``select`` itself: on
        # an instance built without __init__ (copy.copy / pickle create one
        # via __new__), ``self.select`` would re-enter this method forever.
        if attr == 'select':
            raise AttributeError(attr)
        # @@: private names are forwarded too... should they be? Implicit
        # special-method lookups (str(), iter(), [] ...) bypass __getattr__,
        # which is why those dunders are defined explicitly below.
        return getattr(self.select, attr)

    def __repr__(self):
        return '<%s for: %s>' % (self.__class__.__name__, repr(self.select))

    def __str__(self):
        return str(self.select)

    def __iter__(self):
        return iter(self.select)

    def __getitem__(self, key):
        return self.select[key]

    def add(self, obj):
        """Insert an intermediate-table row linking ``forObject`` and *obj*."""
        obj._connection._SO_intermediateInsert(
            self.join.intermediateTable,
            self.join.joinColumn,
            getID(self.forObject),
            self.join.otherColumn,
            getID(obj))

    def remove(self, obj):
        """Delete the intermediate-table row linking ``forObject`` and *obj*."""
        obj._connection._SO_intermediateDelete(
            self.join.intermediateTable,
            self.join.joinColumn,
            getID(self.forObject),
            self.join.otherColumn,
            getID(obj))

    def create(self, **kw):
        """Create an ``otherClass`` row from *kw*, link it, and return it."""
        obj = self.join.otherClass(**kw)
        self.add(obj)
        return obj
0529
0530
class SOOneToMany(object):
    """Descriptor for a one-to-many relation.

    The other class holds a column (``joinColumn``) that references
    ``soClass``. ``otherClass`` is filled in by a class-registry callback.
    """

    def __init__(self, soClass, name, join, joinColumn, **attrs):
        self.soClass = soClass
        self.name = name
        self.joinColumn = joinColumn
        for attrName, attrValue in attrs.items():
            setattr(self, attrName, attrValue)
        registry = classregistry.registry(soClass.sqlmeta.registry)
        registry.addClassCallback(join, self._setOtherClass)

    def _fkClause(self, target):
        # otherTable.joinColumn = target
        return (sqlbuilder.Field(self.otherClass.sqlmeta.table,
                                 self.joinColumn) == target)

    def _setOtherClass(self, otherClass):
        self.otherClass = otherClass
        if not self.joinColumn:
            style = styles.getStyle(self.soClass)
            self.joinColumn = style.tableReference(self.soClass.sqlmeta.table)
        self.clause = self._fkClause(self.soClass.q.id)

    def __get__(self, obj, type):
        if obj is None:
            return self
        select = self.otherClass.select(self._fkClause(obj.id))
        return _OneToManySelectWrapper(obj, self, select)
0560
0561
class OneToMany(boundattributes.BoundFactory):
    """Declarative one-to-many relation; ``factory_class`` is SOOneToMany.

    The meaning of ``__restrict_attributes__`` and ``__unpackargs__`` is
    defined by ``boundattributes.BoundFactory`` (not visible here).
    """
    factory_class = SOOneToMany
    # NOTE(review): presumably the only attributes this factory accepts —
    # confirm against boundattributes.
    __restrict_attributes__ = (
        'join', 'joinColumn')
    # NOTE(review): presumably lets ``join`` be given positionally — confirm.
    __unpackargs__ = ('join',)

    # Default values:
    # (If passed through to SOOneToMany, None makes _setOtherClass derive the
    # column name from the style.)
    joinColumn = None
0570
0571
class _OneToManySelectWrapper(object):
    """Proxy around a ``select`` result, returned by ``SOOneToMany.__get__``.

    Attribute access falls through to the wrapped ``select``. ``create``
    builds a related row that references ``forObject``.
    """

    def __init__(self, forObject, join, select):
        self.forObject = forObject
        self.join = join
        self.select = select

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Guard ``select`` itself: on
        # an instance built without __init__ (copy.copy / pickle create one
        # via __new__), ``self.select`` would re-enter this method forever.
        if attr == 'select':
            raise AttributeError(attr)
        # @@: private names are forwarded too... should they be? Implicit
        # special-method lookups (str(), iter(), [] ...) bypass __getattr__,
        # which is why those dunders are defined explicitly below.
        return getattr(self.select, attr)

    def __repr__(self):
        return '<%s for: %s>' % (self.__class__.__name__, repr(self.select))

    def __str__(self):
        return str(self.select)

    def __iter__(self):
        return iter(self.select)

    def __getitem__(self, key):
        return self.select[key]

    def create(self, **kw):
        """Create and return an ``otherClass`` row referencing ``forObject``."""
        # NOTE(review): joinColumn is used elsewhere as a DB column name
        # (sqlbuilder.Field in SOOneToMany); confirm otherClass(**kw) accepts
        # it rather than the Python attribute name.
        kw[self.join.joinColumn] = self.forObject.id
        return self.join.otherClass(**kw)