comparison rhodecode/lib/dbmigrate/migrate/versioning/genmodel.py @ 833:9753e0907827 beta

Added the dbmigrate package and model changes; moved the upgrade-db command into that package.
author Marcin Kuzminski <marcin@python-works.com>
date Sat, 11 Dec 2010 01:54:12 +0100
parents
children 08d2dcd71666
comparison
equal deleted inserted replaced
832:634596f81cfd 833:9753e0907827
1 """
2 Code to generate a Python model from a database or differences
3 between a model and database.
4
5 Some of this is borrowed heavily from the AutoCode project at:
6 http://code.google.com/p/sqlautocode/
7 """
8
9 import sys
10 import logging
11
12 import sqlalchemy
13
14 import migrate
15 import migrate.changeset
16
17
log = logging.getLogger(__name__)

# Preamble placed first in source generated for a model made of plain
# ``Table`` objects attached to a module-level ``meta``.
HEADER = (
    "\n"
    "## File autogenerated by genmodel.py\n"
    "\n"
    "from sqlalchemy import *\n"
    "meta = MetaData()\n"
)

# Preamble placed first in source generated for a model that uses the
# SQLAlchemy declarative extension (classes deriving from ``Base``).
DECLARATIVE_HEADER = (
    "\n"
    "## File autogenerated by genmodel.py\n"
    "\n"
    "from sqlalchemy import *\n"
    "from sqlalchemy.ext import declarative\n"
    "\n"
    "Base = declarative.declarative_base()\n"
)
35
class ModelGenerator(object):
    """Generate Python model/migration source from a schema diff, or apply it.

    ``diff`` compares a model (side "A", ``diff.metadataA``) against a
    database (side "B", ``diff.metadataB``).  The attributes used here are
    ``tables_missing_from_A`` / ``tables_missing_from_B`` (table names),
    ``tables_different`` (indexed by table name, giving a per-table diff
    with ``columns_missing_from_A``, ``columns_missing_from_B`` and
    ``columns_different``) and the two metadata objects.
    """

    def __init__(self, diff, engine, declarative=False):
        """
        :param diff: schema diff between the model (A) and the database (B).
        :param engine: engine for the database; used by :meth:`applyModel`
            and to detect SQLite.
        :param declarative: when true, generated source uses declarative
            classes instead of ``Table(...)`` objects.
        """
        self.diff = diff
        self.engine = engine
        self.declarative = declarative

    def column_repr(self, col):
        """Return Python source text that recreates column ``col``.

        Keyword arguments are emitted only when set: ``key`` (when it
        differs from the name), ``primary_key``, ``nullable`` (when false),
        ``onupdate`` and ``default`` (the latter suppressed for primary
        keys).  The column type is replaced by the first class in its MRO
        that lives in ``sqlalchemy.types`` and whose name is not all
        upper-case.  ``col`` itself is not modified.
        """
        kwargs = []
        if col.key != col.name:
            kwargs.append(('key', col.key))
        if col.primary_key:
            # Emit a real bool (the column may carry e.g. 1, which would be
            # dumped as 1) without mutating the caller's Column object.
            kwargs.append(('primary_key', True))
        if not col.nullable:
            kwargs.append(('nullable', col.nullable))
        if col.onupdate:
            kwargs.append(('onupdate', col.onupdate))
        if col.default and not col.primary_key:
            # I found that PostgreSQL automatically creates a default
            # value for the sequence, but let's not show that.
            kwargs.append(('default', col.default))
        ks = ', '.join('%s=%r' % (k, v) for k, v in kwargs)

        # crs: gets rid of the extra u'' prefix for Python 2 unicode names.
        # Names that are already ``str`` are left alone: encoding them on
        # Python 3 would give bytes and render as b'...'.
        name = col.name
        if not isinstance(name, str):
            name = name.encode('utf8')

        type_ = col.type
        for cls in col.type.__class__.__mro__:
            if (cls.__module__ == 'sqlalchemy.types'
                    and not cls.__name__.isupper()):
                if cls is not type_.__class__:
                    type_ = cls()
                break

        constraints = ', '.join([repr(cn) for cn in col.constraints])
        args = ks
        if constraints and args:
            args = ',' + args
        data = {
            'name': name,
            'type': type_,
            'constraints': constraints,
            'args': args,
            'maybeComma': ',' if (constraints or args) else '',
        }
        data['commonStuff'] = (
            """ %(maybeComma)s %(constraints)s %(args)s)""" % data).strip()
        if self.declarative:
            return """%(name)s = Column(%(type)r%(commonStuff)s""" % data
        return """Column(%(name)r, %(type)r%(commonStuff)s""" % data

    def getTableDefn(self, table):
        """Return the list of source lines that define ``table``.

        Produces a declarative class when ``self.declarative`` is set,
        otherwise a ``Table(...)`` assignment attached to ``meta``.
        """
        tableName = table.name
        if self.declarative:
            out = ["class %(table)s(Base):" % {'table': tableName},
                   "  __tablename__ = '%(table)s'" % {'table': tableName}]
            for col in table.columns:
                out.append("  %s" % self.column_repr(col))
        else:
            out = ["%(table)s = Table('%(table)s', meta," %
                   {'table': tableName}]
            for col in table.columns:
                out.append("  %s," % self.column_repr(col))
            out.append(")")
        return out

    def _get_tables(self, missingA=False, missingB=False, modified=False):
        """Yield the table objects for the selected diff categories.

        Tables missing from A (the model) are taken from ``metadataB``;
        tables missing from B and modified tables from ``metadataA``.
        ``None`` is yielded for a name the metadata does not contain.
        """
        selections = (
            (missingA, self.diff.tables_missing_from_A, self.diff.metadataB),
            (missingB, self.diff.tables_missing_from_B, self.diff.metadataA),
            (modified, self.diff.tables_different, self.diff.metadataA),
        )
        for wanted, names, metadata in selections:
            if wanted:
                for name in names:
                    yield metadata.tables.get(name)

    def toPython(self):
        """Assume database is current and model is empty.

        Return model source defining every table listed in
        ``tables_missing_from_A``, preceded by the matching header.
        """
        if self.declarative:
            out = [DECLARATIVE_HEADER, ""]
        else:
            out = [HEADER, ""]
        for table in self._get_tables(missingA=True):
            out.extend(self.getTableDefn(table))
            out.append("")
        return '\n'.join(out)

    def toUpgradeDowngradePython(self, indent='    '):
        ''' Assume model is most current and database is out-of-date.

        Return a ``(declarations, upgrade, downgrade)`` tuple of source
        strings.  ``declarations`` defines every added, removed or modified
        table; ``upgrade`` and ``downgrade`` are function bodies whose lines
        (including the leading ``meta.bind`` line) are prefixed by
        ``indent``.
        '''
        decls = ['from migrate.changeset import schema',
                 'meta = MetaData()']
        for table in self._get_tables(
                missingA=True, missingB=True, modified=True):
            decls.extend(self.getTableDefn(table))

        upgradeCommands, downgradeCommands = [], []

        # Tables only in the database: the upgrade drops them.
        for tableName in self.diff.tables_missing_from_A:
            upgradeCommands.append("%(table)s.drop()" % {'table': tableName})
            downgradeCommands.append("%(table)s.create()" %
                                     {'table': tableName})
        # Tables only in the model: the upgrade creates them.
        for tableName in self.diff.tables_missing_from_B:
            upgradeCommands.append("%(table)s.create()" % {'table': tableName})
            downgradeCommands.append("%(table)s.drop()" % {'table': tableName})

        for tableName in self.diff.tables_different:
            # Same per-table diff object that applyModel() consumes.
            td = self.diff.tables_different[tableName]
            for col in td.columns_missing_from_B:
                # Column present in the model but not in the database.
                upgradeCommands.append(
                    '%s.columns[%r].create()' % (tableName, col))
                downgradeCommands.append(
                    '%s.columns[%r].drop()' % (tableName, col))
            for col in td.columns_missing_from_A:
                # Column present in the database but not in the model.
                # NOTE(review): modified tables are declared from the model
                # metadata (metadataA), so this column is presumably absent
                # from the generated Table and the emitted line may fail
                # when the migration runs -- confirm.
                upgradeCommands.append(
                    '%s.columns[%r].drop()' % (tableName, col))
                downgradeCommands.append(
                    '%s.columns[%r].create()' % (tableName, col))
            # NOTE(review): assumes iterating columns_different yields
            # column names (e.g. a dict keyed by name) -- confirm.
            for col in td.columns_different:
                alter = ('assert False, "Can\'t alter columns: %s:%s=>%s"'
                         % (tableName, col, col))
                upgradeCommands.append(alter)
                downgradeCommands.append(alter)

        pre_command = '%smeta.bind = migrate_engine' % indent
        return (
            '\n'.join(decls),
            '\n'.join([pre_command] +
                      ['%s%s' % (indent, line) for line in upgradeCommands]),
            '\n'.join([pre_command] +
                      ['%s%s' % (indent, line) for line in downgradeCommands]))

    def _db_can_handle_this_change(self, td):
        """Return True if ``td`` can be applied without rebuilding the table.

        Changes that only add columns (columns missing from the database)
        are accepted for every backend; anything else is accepted unless
        the engine's driver name starts with ``sqlite``.
        """
        if (td.columns_missing_from_B
                and not td.columns_missing_from_A
                and not td.columns_different):
            # Even sqlite can handle this.
            return True
        return not self.engine.url.drivername.startswith('sqlite')

    def applyModel(self):
        """Apply model to current database.

        Drops tables that exist only in the database and creates tables that
        exist only in the model.  A modified table either gets columns
        created/dropped in place, or (when
        :meth:`_db_can_handle_this_change` says no) is rebuilt from the
        model by :meth:`_rebuild_table`.  The in-place path does not apply
        ``columns_different``.
        """
        meta = sqlalchemy.MetaData(self.engine)

        for table in self._get_tables(missingA=True):
            table.tometadata(meta).drop()
        for table in self._get_tables(missingB=True):
            table.tometadata(meta).create()
        for modelTable in self._get_tables(modified=True):
            tableName = modelTable.name
            modelTable = modelTable.tometadata(meta)
            dbTable = self.diff.metadataB.tables[tableName]
            td = self.diff.tables_different[tableName]

            if self._db_can_handle_this_change(td):
                for col in td.columns_missing_from_B:
                    modelTable.columns[col].create()
                for col in td.columns_missing_from_A:
                    dbTable.columns[col].drop()
                # XXX handle column changes here.
            else:
                self._rebuild_table(modelTable, dbTable)

    def _rebuild_table(self, modelTable, dbTable):
        """Recreate ``modelTable``, keeping data in columns shared with ``dbTable``.

        Sqlite doesn't support drop column, so you have to do more: create
        temp table, copy data to it, drop old table, create new table, copy
        data back.  All statements run in one transaction that is rolled
        back on any error, so we don't leave the database in a nasty state.
        The connection is always closed.
        """
        tableName = modelTable.name
        # TODO: I wonder if this is guaranteed to be unique?
        tempName = '_temp_%s' % tableName
        commonCols = ', '.join(
            col.name for col in modelTable.columns
            if col.name in dbTable.columns)

        connection = self.engine.connect()
        try:
            trans = connection.begin()
            try:
                connection.execute(
                    'CREATE TEMPORARY TABLE %s as SELECT * from %s' %
                    (tempName, tableName))
                # make sure the drop takes place inside our
                # transaction with the bind parameter
                modelTable.drop(bind=connection)
                modelTable.create(bind=connection)
                if commonCols:
                    # With no shared columns "INSERT INTO t () ..." would
                    # be invalid SQL, so there is nothing to copy back.
                    connection.execute(
                        'INSERT INTO %s (%s) SELECT %s FROM %s' %
                        (tableName, commonCols, commonCols, tempName))
                connection.execute('DROP TABLE %s' % tempName)
                trans.commit()
            except:
                # Deliberately bare: roll back on anything (including
                # KeyboardInterrupt), then re-raise unchanged.
                trans.rollback()
                raise
        finally:
            connection.close()