diff rhodecode/lib/dbmigrate/migrate/changeset/ansisql.py @ 833:9753e0907827 beta

added dbmigrate package, added model changes moved out upgrade db command to that package
author Marcin Kuzminski <marcin@python-works.com>
date Sat, 11 Dec 2010 01:54:12 +0100
parents
children 08d2dcd71666
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/rhodecode/lib/dbmigrate/migrate/changeset/ansisql.py	Sat Dec 11 01:54:12 2010 +0100
@@ -0,0 +1,358 @@
+"""
+   Extensions to SQLAlchemy for altering existing tables.
+
+   At the moment, this isn't so much based off of ANSI as much as
+   things that just happen to work with multiple databases.
+"""
+import StringIO
+
+import sqlalchemy as sa
+from sqlalchemy.schema import SchemaVisitor
+from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.sql import ClauseElement
+from sqlalchemy.schema import (ForeignKeyConstraint,
+                               PrimaryKeyConstraint,
+                               CheckConstraint,
+                               UniqueConstraint,
+                               Index)
+
+from migrate import exceptions
+from migrate.changeset import constraint, SQLA_06
+
+if not SQLA_06:
+    from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
+else:
+    from sqlalchemy.schema import AddConstraint, DropConstraint
+    from sqlalchemy.sql.compiler import DDLCompiler
+    SchemaGenerator = SchemaDropper = DDLCompiler
+
+
+class AlterTableVisitor(SchemaVisitor):
+    """Common operations for ``ALTER TABLE`` statements."""
+
+    if SQLA_06:
+        # engine.Compiler looks for .statement
+        # when it spawns off a new compiler
+        statement = ClauseElement()
+
+    def append(self, s):
+        """Append content to the SchemaIterator's query buffer."""
+
+        self.buffer.write(s)
+
+    def execute(self):
+        """Execute the contents of the SchemaIterator's buffer."""
+        try:
+            return self.connection.execute(self.buffer.getvalue())
+        finally:
+            self.buffer.truncate(0)
+
+    def __init__(self, dialect, connection, **kw):
+        self.connection = connection
+        self.buffer = StringIO.StringIO()
+        self.preparer = dialect.identifier_preparer
+        self.dialect = dialect
+
+    def traverse_single(self, elem):
+        ret = super(AlterTableVisitor, self).traverse_single(elem)
+        if ret:
+            # adapt to 0.6 which uses a string-returning
+            # object
+            self.append(" %s" % ret)
+            
+    def _to_table(self, param):
+        """Returns the table object for the given param object."""
+        if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
+            ret = param.table
+        else:
+            ret = param
+        return ret
+
+    def start_alter_table(self, param):
+        """Returns the start of an ``ALTER TABLE`` SQL-Statement.
+
+        Use the param object to determine the table name and use it
+        for building the SQL statement.
+
+        :param param: object to determine the table from
+        :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`,
+          :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`,
+          or string (table name)
+        """
+        table = self._to_table(param)
+        self.append('\nALTER TABLE %s ' % self.preparer.format_table(table))
+        return table
+
+
+class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
+    """Extends ansisql generator for column creation (alter table add col)"""
+
+    def visit_column(self, column):
+        """Create a column (table already exists).
+
+        :param column: column object
+        :type column: :class:`sqlalchemy.Column` instance
+        """
+        if column.default is not None:
+            self.traverse_single(column.default)
+
+        table = self.start_alter_table(column)
+        self.append("ADD ")
+        self.append(self.get_column_specification(column))
+
+        for cons in column.constraints:
+            self.traverse_single(cons)
+        self.execute()
+
+        # ALTER TABLE STATEMENTS
+
+        # add indexes and unique constraints
+        if column.index_name:
+            Index(column.index_name,column).create()
+        elif column.unique_name:
+            constraint.UniqueConstraint(column,
+                                        name=column.unique_name).create()
+
+        # SA bounds FK constraints to table, add manually
+        for fk in column.foreign_keys:
+            self.add_foreignkey(fk.constraint)
+
+        # add primary key constraint if needed
+        if column.primary_key_name:
+            cons = constraint.PrimaryKeyConstraint(column,
+                                                   name=column.primary_key_name)
+            cons.create()
+
+    if SQLA_06:
+        def add_foreignkey(self, fk):
+            self.connection.execute(AddConstraint(fk))
+
+class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
+    """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
+    DROP COLUMN``).
+    """
+
+    def visit_column(self, column):
+        """Drop a column from its table.
+
+        :param column: the column object
+        :type column: :class:`sqlalchemy.Column`
+        """
+        table = self.start_alter_table(column)
+        self.append('DROP COLUMN %s' % self.preparer.format_column(column))
+        self.execute()
+
+
+class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
+    """Manages changes to existing schema elements.
+
+    Note that columns are schema elements; ``ALTER TABLE ADD COLUMN``
+    is in SchemaGenerator.
+
+    All items may be renamed. Columns can also have many of their properties -
+    type, for example - changed.
+
+    Each function is passed a tuple, containing (object, name); where
+    object is a type of object you'd expect for that function
+    (ie. table for visit_table) and name is the object's new
+    name. NONE means the name is unchanged.
+    """
+
+    def visit_table(self, table):
+        """Rename a table. Other ops aren't supported."""
+        self.start_alter_table(table)
+        self.append("RENAME TO %s" % self.preparer.quote(table.new_name,
+                                                         table.quote))
+        self.execute()
+
+    def visit_index(self, index):
+        """Rename an index"""
+        if hasattr(self, '_validate_identifier'):
+            # SA <= 0.6.3
+            self.append("ALTER INDEX %s RENAME TO %s" % (
+                    self.preparer.quote(
+                        self._validate_identifier(
+                            index.name, True), index.quote),
+                    self.preparer.quote(
+                        self._validate_identifier(
+                            index.new_name, True), index.quote)))
+        else:
+            # SA >= 0.6.5
+            self.append("ALTER INDEX %s RENAME TO %s" % (
+                    self.preparer.quote(
+                        self._index_identifier(
+                            index.name), index.quote),
+                    self.preparer.quote(
+                        self._index_identifier(
+                            index.new_name), index.quote)))
+        self.execute()
+
+    def visit_column(self, delta):
+        """Rename/change a column."""
+        # ALTER COLUMN is implemented as several ALTER statements
+        keys = delta.keys()
+        if 'type' in keys:
+            self._run_subvisit(delta, self._visit_column_type)
+        if 'nullable' in keys:
+            self._run_subvisit(delta, self._visit_column_nullable)
+        if 'server_default' in keys:
+            # Skip 'default': only handle server-side defaults, others
+            # are managed by the app, not the db.
+            self._run_subvisit(delta, self._visit_column_default)
+        if 'name' in keys:
+            self._run_subvisit(delta, self._visit_column_name, start_alter=False)
+
+    def _run_subvisit(self, delta, func, start_alter=True):
+        """Runs visit method based on what needs to be changed on column"""
+        table = self._to_table(delta.table)
+        col_name = delta.current_name
+        if start_alter:
+            self.start_alter_column(table, col_name)
+        ret = func(table, delta.result_column, delta)
+        self.execute()
+
+    def start_alter_column(self, table, col_name):
+        """Starts ALTER COLUMN"""
+        self.start_alter_table(table)
+        self.append("ALTER COLUMN %s " % self.preparer.quote(col_name, table.quote))
+
+    def _visit_column_nullable(self, table, column, delta):
+        nullable = delta['nullable']
+        if nullable:
+            self.append("DROP NOT NULL")
+        else:
+            self.append("SET NOT NULL")
+
+    def _visit_column_default(self, table, column, delta):
+        default_text = self.get_column_default_string(column)
+        if default_text is not None:
+            self.append("SET DEFAULT %s" % default_text)
+        else:
+            self.append("DROP DEFAULT")
+
+    def _visit_column_type(self, table, column, delta):
+        type_ = delta['type']
+        if SQLA_06:
+            type_text = str(type_.compile(dialect=self.dialect))
+        else:
+            type_text = type_.dialect_impl(self.dialect).get_col_spec()
+        self.append("TYPE %s" % type_text)
+
+    def _visit_column_name(self, table, column, delta):
+        self.start_alter_table(table)
+        col_name = self.preparer.quote(delta.current_name, table.quote)
+        new_name = self.preparer.format_column(delta.result_column)
+        self.append('RENAME COLUMN %s TO %s' % (col_name, new_name))
+
+
+class ANSIConstraintCommon(AlterTableVisitor):
+    """
+    Migrate's constraints require a separate creation function from
+    SA's: Migrate's constraints are created independently of a table;
+    SA's are created at the same time as the table.
+    """
+
+    def get_constraint_name(self, cons):
+        """Gets a name for the given constraint.
+
+        If the name is already set it will be used otherwise the
+        constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>`
+        method is used.
+
+        :param cons: constraint object
+        """
+        if cons.name is not None:
+            ret = cons.name
+        else:
+            ret = cons.name = cons.autoname()
+        return self.preparer.quote(ret, cons.quote)
+
+    def visit_migrate_primary_key_constraint(self, *p, **k):
+        self._visit_constraint(*p, **k)
+
+    def visit_migrate_foreign_key_constraint(self, *p, **k):
+        self._visit_constraint(*p, **k)
+
+    def visit_migrate_check_constraint(self, *p, **k):
+        self._visit_constraint(*p, **k)
+
+    def visit_migrate_unique_constraint(self, *p, **k):
+        self._visit_constraint(*p, **k)
+
+if SQLA_06:
+    class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
+        def _visit_constraint(self, constraint):
+            constraint.name = self.get_constraint_name(constraint)
+            self.append(self.process(AddConstraint(constraint)))
+            self.execute()
+
+    class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
+        def _visit_constraint(self, constraint):
+            constraint.name = self.get_constraint_name(constraint)
+            self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
+            self.execute()
+
+else:
+    class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
+
+        def get_constraint_specification(self, cons, **kwargs):
+            """Constaint SQL generators.
+        
+            We cannot use SA visitors because they append comma.
+            """
+        
+            if isinstance(cons, PrimaryKeyConstraint):
+                if cons.name is not None:
+                    self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
+                self.append("PRIMARY KEY ")
+                self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
+                                               for c in cons))
+                self.define_constraint_deferrability(cons)
+            elif isinstance(cons, ForeignKeyConstraint):
+                self.define_foreign_key(cons)
+            elif isinstance(cons, CheckConstraint):
+                if cons.name is not None:
+                    self.append("CONSTRAINT %s " %
+                                self.preparer.format_constraint(cons))
+                self.append("CHECK (%s)" % cons.sqltext)
+                self.define_constraint_deferrability(cons)
+            elif isinstance(cons, UniqueConstraint):
+                if cons.name is not None:
+                    self.append("CONSTRAINT %s " %
+                                self.preparer.format_constraint(cons))
+                self.append("UNIQUE (%s)" % \
+                    (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
+                self.define_constraint_deferrability(cons)
+            else:
+                raise exceptions.InvalidConstraintError(cons)
+
+        def _visit_constraint(self, constraint):
+        
+            table = self.start_alter_table(constraint)
+            constraint.name = self.get_constraint_name(constraint)
+            self.append("ADD ")
+            self.get_constraint_specification(constraint)
+            self.execute()
+    
+
+    class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
+
+        def _visit_constraint(self, constraint):
+            self.start_alter_table(constraint)
+            self.append("DROP CONSTRAINT ")
+            constraint.name = self.get_constraint_name(constraint)
+            self.append(self.preparer.format_constraint(constraint))
+            if constraint.cascade:
+                self.cascade_constraint(constraint)
+            self.execute()
+
+        def cascade_constraint(self, constraint):
+            self.append(" CASCADE")
+
+
+class ANSIDialect(DefaultDialect):
+    columngenerator = ANSIColumnGenerator
+    columndropper = ANSIColumnDropper
+    schemachanger = ANSISchemaChanger
+    constraintgenerator = ANSIConstraintGenerator
+    constraintdropper = ANSIConstraintDropper