0001from sqlobject.dbconnection import DBAPI
0002import re
0003from sqlobject import col
0004from sqlobject import sqlbuilder
0005from sqlobject.converters import registerConverter
# Driver modules are imported lazily by PostgresConnection.__init__ (via
# ``global``); until then both module-level placeholders stay None.
psycopg = pgdb = None
0008
class PostgresConnection(DBAPI):
    """SQLObject connection backend for PostgreSQL.

    Talks to the server through psycopg (the default) or PyGreSQL's
    ``pgdb`` module (``usePygresql=True``).  The chosen driver module is
    imported lazily the first time a connection object is constructed.
    """

    supportTransactions = True
    dbName = 'postgres'
    schemes = [dbName, 'postgresql', 'psycopg']

    # Integer type names as spelled by pg_catalog.format_type().  An exact
    # match is required: a plain substring test for "int" also matched
    # "interval" and "point".
    _intTypeRE = re.compile(r'^(?:integer|smallint|bigint|int[248]?)$')
    # Extracts the declared length from e.g. "character varying(20)".
    _typeLengthRE = re.compile(r'\((\d+)\)')

    def __init__(self, dsn=None, host=None, port=None, db=None,
                 user=None, passwd=None, usePygresql=False, unicodeCols=False,
                 **kw):
        """Configure (but do not open) a PostgreSQL connection.

        :param dsn: explicit driver connection string; when given it is
            passed verbatim to the driver's ``connect()``.  Otherwise the
            keyword parameters below are passed to ``connect()``.
        :param host, port, db, user, passwd: connection parameters.
        :param usePygresql: use ``pgdb`` instead of ``psycopg``.
        :param unicodeCols: make ``columnsFromSchema`` report string
            columns as ``UnicodeCol`` in the server's client encoding.
        :param kw: forwarded unchanged to ``DBAPI.__init__``.
        """
        global psycopg, pgdb
        self.usePygresql = usePygresql
        if usePygresql:
            if pgdb is None:
                import pgdb
            self.module = pgdb
        else:
            if psycopg is None:
                import psycopg
            self.module = psycopg
            # The Binary converter is psycopg-specific.  It must only be
            # registered on this branch: on the PyGreSQL path the global
            # ``psycopg`` may still be None.
            registerConverter(type(psycopg.Binary('')),
                              PsycoBinaryConverter)

        self.user = user
        self.host = host
        self.port = port
        self.db = db
        self.password = passwd
        self.dsn_dict = self._makeDsnDict(host, port, db, user, passwd)
        # Decide this before a DSN string is synthesised below: a
        # synthesised string is only used in error messages, while actual
        # connections are then made from dsn_dict.
        self.use_dsn = dsn is not None
        if dsn is None:
            dsn = self._makeDsnString(host, port, db, user, passwd)
        self.dsn = dsn
        self.unicodeCols = unicodeCols
        DBAPI.__init__(self, **kw)

        # Filled in lazily by the server_version property.
        self._server_version = None

    def _makeDsnDict(self, host, port, db, user, passwd):
        """Return the keyword arguments passed to the driver's connect()."""
        dsn_dict = {}
        if host:
            dsn_dict["host"] = host
        if port:
            if self.usePygresql:
                # pgdb takes the port folded into the host as "host:port".
                # Use an empty host part rather than the text "None" when
                # no host was given.
                dsn_dict["host"] = "%s:%d" % (host or '', int(port))
            else:
                dsn_dict["port"] = str(port)
        if db:
            dsn_dict["database"] = db
        if user:
            dsn_dict["user"] = user
        if passwd:
            dsn_dict["password"] = passwd
        return dsn_dict

    def _makeDsnString(self, host, port, db, user, passwd):
        """Return a driver-style DSN string describing the parameters."""
        if self.usePygresql:
            # PyGreSQL format: "host:database:user:password".
            return ':'.join([host or '', db or '', user or '', passwd or ''])
        parts = []
        if db:
            parts.append('dbname=%s' % db)
        if user:
            parts.append('user=%s' % user)
        if passwd:
            parts.append('password=%s' % passwd)
        if host:
            parts.append('host=%s' % host)
        if port:
            parts.append('port=%d' % int(port))
        return ' '.join(parts)

    def connectionFromURI(cls, uri):
        """Build a connection from a URI such as postgres://user@host/db."""
        user, password, host, port, path, args = cls._parseURI(uri)
        path = path.strip('/')
        return cls(host=host, port=port, db=path, user=user, passwd=password, **args)
    connectionFromURI = classmethod(connectionFromURI)

    def _setAutoCommit(self, conn, auto):
        """Switch autocommit on ``conn`` if the driver exposes it."""
        # Not every driver connection has an autocommit() method; those
        # that lack it are left untouched.
        if hasattr(conn, 'autocommit'):
            conn.autocommit(auto)

    def makeConnection(self):
        """Open and return a new DB-API connection.

        :raises module.OperationalError: when the driver cannot connect.
            The original message is kept and the connection string is
            appended to it.
        """
        import sys
        try:
            if self.use_dsn:
                conn = self.module.connect(self.dsn)
            else:
                conn = self.module.connect(**self.dsn_dict)
        except self.module.OperationalError:
            # sys.exc_info() is valid on every Python version, unlike
            # "except E, e" (2.x only) or "except E as e" (2.6+ only).
            e = sys.exc_info()[1]
            # NOTE(review): self.dsn may contain the password and is echoed
            # into this message; consider masking it.
            raise self.module.OperationalError(
                "%s; used connection string %r" % (e, self.dsn))
        if self.autoCommit:
            self._setAutoCommit(conn, 1)
        return conn

    def _queryInsertID(self, conn, soInstance, id, names, values):
        """INSERT a row and return its id.

        When ``id`` is None the next value of the id sequence is fetched
        first.  The sequence name comes from ``_idSequence`` and defaults
        to "<table>_<idName>_seq".
        """
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        sequenceName = getattr(soInstance, '_idSequence',
                               '%s_%s_seq' % (table, idName))
        c = conn.cursor()
        try:
            if id is None:
                c.execute("SELECT NEXTVAL('%s')" % sequenceName)
                id = c.fetchone()[0]
            names = [idName] + names
            values = [id] + values
            q = self._insertSQL(table, names, values)
            if self.debug:
                self.printDebug(conn, q, 'QueryIns')
            c.execute(q)
        finally:
            # Release the cursor even if NEXTVAL or the INSERT fails.
            c.close()
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id

    def _queryAddLimitOffset(self, query, start, end):
        """Append LIMIT/OFFSET clauses for the half-open row range [start, end)."""
        if not start:
            return "%s LIMIT %i" % (query, end)
        if not end:
            return "%s OFFSET %i" % (query, start)
        return "%s LIMIT %i OFFSET %i" % (query, end-start, start)

    def createColumn(self, soClass, col):
        """Return the column definition SQL for ``col``."""
        return col.postgresCreateSQL()

    def createIndexSQL(self, soClass, index):
        """Return the CREATE INDEX SQL for ``index``."""
        return index.postgresCreateIndexSQL(soClass)

    def createIDColumn(self, soClass):
        """Return the id column definition as a SERIAL primary key."""
        return '%s SERIAL PRIMARY KEY' % soClass.sqlmeta.idName

    def _serverVersionTuple(self):
        """Return the server version as an ``(major, minor)`` int tuple.

        Returns None when ``server_version`` does not start with a number.
        """
        match = re.match(r'(\d+)(?:\.(\d+))?', self.server_version)
        if match is None:
            return None
        major, minor = match.groups()
        return (int(major), int(minor or 0))

    def dropTable(self, tableName, cascade=False):
        """DROP TABLE ``tableName``, optionally with CASCADE.

        CASCADE is suppressed on servers older than 7.3.
        """
        version = self._serverVersionTuple()
        # Compare numerically: the former string test ("10."[:3] <= "7.2")
        # wrongly treated PostgreSQL 10+ as too old for CASCADE.
        if version is not None and version < (7, 3):
            cascade = False
        self.query("DROP TABLE %s %s" % (tableName,
                                         cascade and 'CASCADE' or ''))

    def joinSQLType(self, join):
        """Return the SQL type used for join-table id columns."""
        return 'INT NOT NULL'

    def tableExists(self, tableName):
        """Return the number of pg_class entries named ``tableName``.

        A non-zero result is truthy, meaning the table exists.
        """
        result = self.queryOne("SELECT COUNT(relname) FROM pg_class WHERE relname = %s"
                               % self.sqlrepr(tableName))
        return result[0]

    def addColumn(self, tableName, column):
        """ALTER TABLE ... ADD COLUMN for ``column``."""
        self.query('ALTER TABLE %s ADD COLUMN %s' %
                   (tableName,
                    column.postgresCreateSQL()))

    def delColumn(self, tableName, column):
        """ALTER TABLE ... DROP COLUMN for ``column``."""
        self.query('ALTER TABLE %s DROP COLUMN %s' %
                   (tableName,
                    column.dbName))

    def _foreignKeyMap(self, tableName):
        """Map FK column name -> guessed referenced class name for a table."""
        keyQuery = """
        SELECT pg_catalog.pg_get_constraintdef(oid) as condef
        FROM pg_catalog.pg_constraint r
        WHERE r.conrelid = %s::regclass AND r.contype = 'f'"""
        keyRE = re.compile(r"\((.+)\) REFERENCES (.+)\(")
        keymap = {}
        for (condef,) in self.queryAll(keyQuery % self.sqlrepr(tableName)):
            match = keyRE.search(condef)
            if match:
                field, reftable = match.groups()
                # Heuristic: the referenced class name is taken to be the
                # capitalised table name.
                keymap[field] = reftable.capitalize()
        return keymap

    def _primaryKeyName(self, tableName):
        """Return the (unquoted) primary key column name of ``tableName``."""
        primaryKeyQuery = """
        SELECT pg_index.indisprimary,
            pg_catalog.pg_get_indexdef(pg_index.indexrelid)
        FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
            pg_catalog.pg_index AS pg_index
        WHERE c.relname = %s
            AND c.oid = pg_index.indrelid
            AND pg_index.indexrelid = c2.oid
            AND pg_index.indisprimary
        """
        primaryRE = re.compile(r'CREATE .*? USING .* \((.+?)\)')
        primaryKey = None
        primaryData = self.queryAll(primaryKeyQuery % self.sqlrepr(tableName))
        for isPrimary, indexDef in primaryData:
            match = primaryRE.search(indexDef)
            assert match, "Unparseable contraint definition: %r" % indexDef
            assert primaryKey is None, "Already found primary key (%r), then found: %r" % (primaryKey, indexDef)
            primaryKey = match.group(1)
        assert primaryKey, "No primary key found in table %r" % tableName
        if primaryKey.startswith('"'):
            assert primaryKey.endswith('"')
            primaryKey = primaryKey[1:-1]
        return primaryKey

    def columnsFromSchema(self, tableName, soClass):
        """Build SQLObject column objects by introspecting ``tableName``.

        The primary key column is skipped.  Foreign keys, NOT NULL flags
        and column defaults are carried over into the column keywords.
        """
        keymap = self._foreignKeyMap(tableName)
        primaryKey = self._primaryKeyName(tableName)

        # NOTE(review): pg_attrdef.adsrc was removed in PostgreSQL 12;
        # pg_get_expr(d.adbin, d.adrelid) is the replacement (7.4+).
        # Confirm the minimum supported server version before switching.
        colQuery = """
        SELECT a.attname,
            pg_catalog.format_type(a.atttypid, a.atttypmod), a.attnotnull,
            (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
             WHERE d.adrelid=a.attrelid AND d.adnum = a.attnum)
        FROM pg_catalog.pg_attribute a
        WHERE a.attrelid =%s::regclass
            AND a.attnum > 0 AND NOT a.attisdropped
        ORDER BY a.attnum"""

        colData = self.queryAll(colQuery % self.sqlrepr(tableName))
        results = []
        if self.unicodeCols:
            client_encoding = self.queryOne("SHOW client_encoding")[0]
        for field, t, notnull, defaultstr in colData:
            if field == primaryKey:
                continue
            colClass, kw = self.guessClass(t)
            if self.unicodeCols and colClass == col.StringCol:
                colClass = col.UnicodeCol
                kw['dbEncoding'] = client_encoding
            kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            kw['notNone'] = notnull
            if defaultstr is not None:
                kw['default'] = self.defaultFromSchema(colClass, defaultstr)
            elif not notnull:
                kw['default'] = None
            if field in keymap:
                kw['foreignKey'] = keymap[field]
            results.append(colClass(**kw))
        return results

    def guessClass(self, t):
        """Map a format_type() type name to ``(ColumnClass, extra_kwargs)``."""
        if self._intTypeRE.match(t):
            return col.IntCol, {}
        elif 'varying' in t:
            match = self._typeLengthRE.search(t)
            if match:
                return col.StringCol, {'length': int(match.group(1))}
            return col.StringCol, {}
        elif t.startswith('character('):
            match = self._typeLengthRE.search(t)
            return col.StringCol, {'length': int(match.group(1)),
                                   'varchar': False}
        elif t == 'text':
            return col.StringCol, {}
        elif t.startswith('datetime') or t.startswith('timestamp'):
            # format_type() spells these "timestamp [(p)] with[out] time zone".
            return col.DateTimeCol, {}
        elif t.startswith('bool'):
            return col.BoolCol, {}
        elif t.startswith('bytea'):
            return col.BLOBCol, {}
        else:
            return col.Col, {}

    def defaultFromSchema(self, colClass, defaultstr):
        """
        If the default can be converted to a python constant, convert it.
        Otherwise return it as a sqlbuilder constant.
        """
        if colClass == col.BoolCol:
            if defaultstr == 'false':
                return False
            elif defaultstr == 'true':
                return True
        return getattr(sqlbuilder.const, defaultstr)

    def server_version(self):
        """Server version string (e.g. "8.1.4"), queried once and cached."""
        if self._server_version is None:
            # version() looks like "PostgreSQL 8.1.4 on ..."; keep word 2.
            server_version = self.queryOne("SELECT version()")[0]
            self._server_version = server_version.split()[1]
        return self._server_version
    server_version = property(server_version)

    def createEmptyDatabase(self):
        """Create the database ``self.db`` from ``template0``.

        Connects to ``template1`` using this connection's host, port, user
        and password, then issues CREATE DATABASE.
        """
        # Reuse the keyword arguments makeConnection() uses so the port is
        # honoured; the old hand-built DSN strings silently dropped it.
        params = dict(self.dsn_dict)
        params['database'] = 'template1'
        conn = self.module.connect(**params)
        try:
            cur = conn.cursor()
            try:
                # CREATE DATABASE cannot run inside a transaction block;
                # presumably this ends the one the driver opened implicitly.
                cur.execute('COMMIT')
                cur.execute('CREATE DATABASE %s TEMPLATE=template0' % self.db)
            finally:
                cur.close()
        finally:
            conn.close()
0314
0315
0316
0317
def PsycoBinaryConverter(value, db):
    """Render a psycopg ``Binary`` object as SQL text.

    Registered for psycopg's Binary type in PostgresConnection.__init__.
    The string form of the wrapper is used directly as the SQL literal.

    :param value: the Binary wrapper instance.
    :param db: backend name; only 'postgres' is expected here.
    """
    # Sanity check: this converter is only meaningful for PostgreSQL.
    assert 'postgres' == db
    return str(value)