0001from datetime import datetime
0002from sqlobject import col, events, SQLObject, AND
0003
0004
class Version(SQLObject):
    """Base class for the generated ``<Master>Versions`` archive classes.

    Concrete subclasses are created with ``type(...)`` by
    ``Versioning.__addtoclass__``.  They are expected to carry:

    * ``masterClass`` -- the versioned SQLObject class,
    * ``extraCols`` -- mapping of additional archive-only columns,
    * a ``master`` ForeignKey (``masterID``) and a ``dateArchived`` column,
    * copies of the master class's own columns.

    Attributes not found on a snapshot are looked up on its master row
    (see ``__getattr__``).
    """

    def restore(self):
        """Copy this snapshot's column values back onto its master row.

        The bookkeeping columns ``id``, ``masterID`` and ``dateArchived``
        and every key of ``extraCols`` are removed before the remaining
        values are passed to ``set()`` on the master row.
        """
        values = self.sqlmeta.asDict()
        del values['id']
        del values['masterID']
        del values['dateArchived']
        for _col in self.extraCols:
            del values[_col]
        self.masterClass.get(self.masterID).set(**values)

    def nextVersion(self):
        """Return the snapshot that follows this one, or the live master row.

        Snapshots of the same master are ordered by ``id``.  The first one
        with a larger ``id`` is returned.  When none exists, this is the
        newest snapshot and ``self.master`` (the current row) is returned.
        """
        later = self.select(
            AND(self.q.masterID == self.masterID, self.q.id > self.id),
            orderBy=self.q.id)
        # One LIMIT 1 query instead of a COUNT(*) followed by a second,
        # separately LIMIT-ed SELECT for ``later[0]``.
        following = list(later[:1])
        if following:
            return following[0]
        return self.master

    def getChangedFields(self):
        """Return title-cased names of master columns changed after this snapshot.

        Each column of ``masterClass.sqlmeta.columns`` (other than the
        bookkeeping names) is compared between this snapshot and whatever
        ``nextVersion()`` returns.
        """
        successor = self.nextVersion()
        columns = self.masterClass.sqlmeta.columns
        fields = []
        for column in columns:
            if column in ["dateArchived", "id", "masterID"]:
                continue
            if getattr(self, column) != getattr(successor, column):
                fields.append(column.title())
        return fields

    @classmethod
    def select(cls, clause=None, *args, **kw):
        """``SQLObject.select`` that borrows the master class's connection.

        Generated version classes may not have been given a ``_connection``
        of their own; in that case the master class's one is copied onto
        the version class before querying.
        """
        if not getattr(cls, '_connection', None):
            cls._connection = cls.masterClass._connection
        return super(Version, cls).select(clause, *args, **kw)

    def __getattr__(self, attr):
        """Delegate unknown attribute lookups to the master row.

        Only invoked when normal lookup fails.

        :raises AttributeError: if ``master`` itself is the missing
            attribute; delegating would evaluate ``self.master`` again and
            recurse until ``RecursionError``.
        """
        if attr in self.__dict__:
            return self.__dict__[attr]
        if attr == 'master':
            raise AttributeError(attr)
        return getattr(self.master, attr)
0046
0047
def getColumns(columns, cls):
    """Copy column definitions of *cls* and all its ancestors into *columns*.

    Each definition is rebuilt from its constructor keywords (``_kw``) with
    the ``alternateID`` and ``unique`` options removed.  Those constraints
    are incompatible with a table that holds several archived copies of a
    row.  A ForeignKey column whose name ends in ``ID`` is stored under the
    name without that suffix.

    :param columns: dict updated in place; maps attribute name to a new
        column definition instance.
    :param cls: class whose ``sqlmeta.columnDefinitions`` are copied first.
        Parent classes (following ``sqlmeta.parentClass``) are visited
        afterwards, so their entries overwrite same-named ones.
    """
    current = cls
    while True:
        for name, definition in current.sqlmeta.columnDefinitions.items():
            if name.endswith("ID") and isinstance(definition, col.ForeignKey):
                name = name[:-2]
            # Work on a copy so the original definition's kwargs stay intact.
            options = dict(definition._kw)
            for constraint in ("alternateID", "unique"):
                options.pop(constraint, None)
            columns[name] = definition.__class__(**options)

        current = current.sqlmeta.parentClass
        if not current:
            break
0063
0064
0065class Versioning(object):
0066    def __init__(self, extraCols=None):
0067        if extraCols:
0068            self.extraCols = extraCols
0069        else:
0070            self.extraCols = {}
0071        pass
0072
0073    def __addtoclass__(self, soClass, name):
0074        self.name = name
0075        self.soClass = soClass
0076
0077        attrs = {'dateArchived': col.DateTimeCol(default=datetime.now),
0078                 'master': col.ForeignKey(self.soClass.__name__),
0079                 'masterClass': self.soClass,
0080                 'extraCols': self.extraCols
0081                 }
0082
0083        getColumns(attrs, self.soClass)
0084
0085        attrs.update(self.extraCols)
0086
0087        self.versionClass = type(self.soClass.__name__ + 'Versions',
0088                                 (Version,),
0089                                 attrs)
0090
0091        if '_connection' in self.soClass.__dict__:
0092            self.versionClass._connection =                   self.soClass.__dict__['_connection']
0094
0095        events.listen(self.createTable,
0096                      soClass, events.CreateTableSignal)
0097        events.listen(self.rowUpdate, soClass,
0098                      events.RowUpdateSignal)
0099
0100    def createVersionTable(self, cls, conn):
0101        self.versionClass.createTable(ifNotExists=True, connection=conn)
0102
0103    def createTable(self, soClass, connection, extra_sql, post_funcs):
0104        assert soClass is self.soClass
0105        post_funcs.append(self.createVersionTable)
0106
0107    def rowUpdate(self, instance, kwargs):
0108        if instance.childName and instance.childName != self.soClass.__name__:
0109            return  # if you want your child class versioned, version it
0110
0111        values = instance.sqlmeta.asDict()
0112        del values['id']
0113        values['masterID'] = instance.id
0114        self.versionClass(connection=instance._connection, **values)
0115
0116    def __get__(self, obj, type=None):
0117        if obj is None:
0118            return self
0119        return self.versionClass.select(
0120            self.versionClass.q.masterID == obj.id, connection=obj._connection)