#!/usr/bin/env python3 """ Database Migration Generator Generates safe migration scripts between schema versions: - Compares current and target schemas - Generates ALTER TABLE statements for schema changes - Implements zero-downtime migration strategies (expand-contract pattern) - Creates rollback scripts for all changes - Generates validation queries to verify migrations - Handles complex changes like table splits/merges Input: Current schema JSON + Target schema JSON Output: Migration SQL + Rollback SQL + Validation queries + Execution plan Usage: python migration_generator.py --current current_schema.json --target target_schema.json --output migration.sql python migration_generator.py --current current.json --target target.json --format json python migration_generator.py --current current.json --target target.json --zero-downtime python migration_generator.py --current current.json --target target.json --validate-only """ import argparse import json import re import sys from collections import defaultdict, OrderedDict from typing import Dict, List, Set, Tuple, Optional, Any, Union from dataclasses import dataclass, asdict from datetime import datetime import hashlib @dataclass class Column: name: str data_type: str nullable: bool = True primary_key: bool = False unique: bool = False foreign_key: Optional[str] = None default_value: Optional[str] = None check_constraint: Optional[str] = None @dataclass class Table: name: str columns: Dict[str, Column] primary_key: List[str] foreign_keys: Dict[str, str] # column -> referenced_table.column unique_constraints: List[List[str]] check_constraints: Dict[str, str] indexes: List[Dict[str, Any]] @dataclass class MigrationStep: step_id: str step_type: str table: str description: str sql_forward: str sql_rollback: str validation_sql: Optional[str] = None dependencies: List[str] = None risk_level: str = "LOW" # LOW, MEDIUM, HIGH estimated_time: Optional[str] = None zero_downtime_phase: Optional[str] = None # EXPAND, CONTRACT, or None @dataclass class MigrationPlan: migration_id: str created_at: str source_schema_hash: str target_schema_hash: str steps: List[MigrationStep] summary: Dict[str, Any] execution_order: List[str] rollback_order: List[str] @dataclass class ValidationCheck: check_id: str check_type: str table: str description: str sql_query: str expected_result: Any critical: bool = True class SchemaComparator: """Compares two schema versions and identifies differences.""" def __init__(self): self.current_schema: Dict[str, Table] = {} self.target_schema: Dict[str, Table] = {} self.changes: Dict[str, List[Dict[str, Any]]] = { 'tables_added': [], 'tables_dropped': [], 'tables_renamed': [], 'columns_added': [], 'columns_dropped': [], 'columns_modified': [], 'columns_renamed': [], 'constraints_added': [], 'constraints_dropped': [], 'indexes_added': [], 'indexes_dropped': [] } def load_schemas(self, current_data: Dict[str, Any], target_data: Dict[str, Any]): """Load current and target schemas.""" self.current_schema = self._parse_schema(current_data) self.target_schema = self._parse_schema(target_data) def _parse_schema(self, schema_data: Dict[str, Any]) -> Dict[str, Table]: """Parse schema JSON into Table objects.""" tables = {} if 'tables' not in schema_data: return tables for table_name, table_def in schema_data['tables'].items(): columns = {} primary_key = table_def.get('primary_key', []) foreign_keys = {} # Parse columns for col_name, col_def in table_def.get('columns', {}).items(): column = Column( name=col_name, data_type=col_def.get('type', 'VARCHAR(255)'), nullable=col_def.get('nullable', True), primary_key=col_name in primary_key, unique=col_def.get('unique', False), foreign_key=col_def.get('foreign_key'), default_value=col_def.get('default'), check_constraint=col_def.get('check_constraint') ) columns[col_name] = column if column.foreign_key: foreign_keys[col_name] = column.foreign_key table = Table( name=table_name, columns=columns, primary_key=primary_key, foreign_keys=foreign_keys, unique_constraints=table_def.get('unique_constraints', []), check_constraints=table_def.get('check_constraints', {}), indexes=table_def.get('indexes', []) ) tables[table_name] = table return tables def compare_schemas(self) -> Dict[str, List[Dict[str, Any]]]: """Compare schemas and identify all changes.""" self._compare_tables() self._compare_columns() self._compare_constraints() self._compare_indexes() return self.changes def _compare_tables(self): """Compare table-level changes.""" current_tables = set(self.current_schema.keys()) target_tables = set(self.target_schema.keys()) # Tables added for table_name in target_tables - current_tables: self.changes['tables_added'].append({ 'table': table_name, 'definition': self.target_schema[table_name] }) # Tables dropped for table_name in current_tables - target_tables: self.changes['tables_dropped'].append({ 'table': table_name, 'definition': self.current_schema[table_name] }) # Tables renamed (heuristic based on column similarity) self._detect_renamed_tables(current_tables - target_tables, target_tables - current_tables) def _detect_renamed_tables(self, dropped_tables: Set[str], added_tables: Set[str]): """Detect renamed tables based on column similarity.""" if not dropped_tables or not added_tables: return # Calculate similarity scores similarity_scores = [] for dropped_table in dropped_tables: for added_table in added_tables: score = self._calculate_table_similarity(dropped_table, added_table) if score > 0.7: # High similarity threshold similarity_scores.append((score, dropped_table, added_table)) # Sort by similarity and identify renames similarity_scores.sort(reverse=True) used_tables = set() for score, old_name, new_name in similarity_scores: if old_name not in used_tables and new_name not in used_tables: self.changes['tables_renamed'].append({ 'old_name': old_name, 'new_name': new_name, 'similarity_score': score }) used_tables.add(old_name) used_tables.add(new_name) # Remove from added/dropped lists self.changes['tables_added'] = [t for t in self.changes['tables_added'] if t['table'] != new_name] self.changes['tables_dropped'] = [t for t in self.changes['tables_dropped'] if t['table'] != old_name] def _calculate_table_similarity(self, table1_name: str, table2_name: str) -> float: """Calculate similarity between two tables based on columns.""" table1 = self.current_schema[table1_name] table2 = self.target_schema[table2_name] cols1 = set(table1.columns.keys()) cols2 = set(table2.columns.keys()) if not cols1 and not cols2: return 1.0 elif not cols1 or not cols2: return 0.0 intersection = len(cols1.intersection(cols2)) union = len(cols1.union(cols2)) return intersection / union def _compare_columns(self): """Compare column-level changes.""" common_tables = set(self.current_schema.keys()).intersection(set(self.target_schema.keys())) for table_name in common_tables: current_table = self.current_schema[table_name] target_table = self.target_schema[table_name] current_columns = set(current_table.columns.keys()) target_columns = set(target_table.columns.keys()) # Columns added for col_name in target_columns - current_columns: self.changes['columns_added'].append({ 'table': table_name, 'column': col_name, 'definition': target_table.columns[col_name] }) # Columns dropped for col_name in current_columns - target_columns: self.changes['columns_dropped'].append({ 'table': table_name, 'column': col_name, 'definition': current_table.columns[col_name] }) # Columns modified for col_name in current_columns.intersection(target_columns): current_col = current_table.columns[col_name] target_col = target_table.columns[col_name] if self._columns_different(current_col, target_col): self.changes['columns_modified'].append({ 'table': table_name, 'column': col_name, 'current_definition': current_col, 'target_definition': target_col, 'changes': self._describe_column_changes(current_col, target_col) }) def _columns_different(self, col1: Column, col2: Column) -> bool: """Check if two columns have different definitions.""" return (col1.data_type != col2.data_type or col1.nullable != col2.nullable or col1.default_value != col2.default_value or col1.unique != col2.unique or col1.foreign_key != col2.foreign_key or col1.check_constraint != col2.check_constraint) def _describe_column_changes(self, current_col: Column, target_col: Column) -> List[str]: """Describe specific changes between column definitions.""" changes = [] if current_col.data_type != target_col.data_type: changes.append(f"type: {current_col.data_type} -> {target_col.data_type}") if current_col.nullable != target_col.nullable: changes.append(f"nullable: {current_col.nullable} -> {target_col.nullable}") if current_col.default_value != target_col.default_value: changes.append(f"default: {current_col.default_value} -> {target_col.default_value}") if current_col.unique != target_col.unique: changes.append(f"unique: {current_col.unique} -> {target_col.unique}") if current_col.foreign_key != target_col.foreign_key: changes.append(f"foreign_key: {current_col.foreign_key} -> {target_col.foreign_key}") return changes def _compare_constraints(self): """Compare constraint changes.""" common_tables = set(self.current_schema.keys()).intersection(set(self.target_schema.keys())) for table_name in common_tables: current_table = self.current_schema[table_name] target_table = self.target_schema[table_name] # Compare primary keys if current_table.primary_key != target_table.primary_key: if current_table.primary_key: self.changes['constraints_dropped'].append({ 'table': table_name, 'constraint_type': 'PRIMARY_KEY', 'columns': current_table.primary_key }) if target_table.primary_key: self.changes['constraints_added'].append({ 'table': table_name, 'constraint_type': 'PRIMARY_KEY', 'columns': target_table.primary_key }) # Compare unique constraints current_unique = set(tuple(uc) for uc in current_table.unique_constraints) target_unique = set(tuple(uc) for uc in target_table.unique_constraints) for constraint in target_unique - current_unique: self.changes['constraints_added'].append({ 'table': table_name, 'constraint_type': 'UNIQUE', 'columns': list(constraint) }) for constraint in current_unique - target_unique: self.changes['constraints_dropped'].append({ 'table': table_name, 'constraint_type': 'UNIQUE', 'columns': list(constraint) }) # Compare check constraints current_checks = set(current_table.check_constraints.items()) target_checks = set(target_table.check_constraints.items()) for name, condition in target_checks - current_checks: self.changes['constraints_added'].append({ 'table': table_name, 'constraint_type': 'CHECK', 'constraint_name': name, 'condition': condition }) for name, condition in current_checks - target_checks: self.changes['constraints_dropped'].append({ 'table': table_name, 'constraint_type': 'CHECK', 'constraint_name': name, 'condition': condition }) def _compare_indexes(self): """Compare index changes.""" common_tables = set(self.current_schema.keys()).intersection(set(self.target_schema.keys())) for table_name in common_tables: current_indexes = {idx['name']: idx for idx in self.current_schema[table_name].indexes} target_indexes = {idx['name']: idx for idx in self.target_schema[table_name].indexes} current_names = set(current_indexes.keys()) target_names = set(target_indexes.keys()) # Indexes added for idx_name in target_names - current_names: self.changes['indexes_added'].append({ 'table': table_name, 'index': target_indexes[idx_name] }) # Indexes dropped for idx_name in current_names - target_names: self.changes['indexes_dropped'].append({ 'table': table_name, 'index': current_indexes[idx_name] }) class MigrationGenerator: """Generates migration steps from schema differences.""" def __init__(self, zero_downtime: bool = False): self.zero_downtime = zero_downtime self.migration_steps: List[MigrationStep] = [] self.step_counter = 0 # Data type conversion safety self.safe_type_conversions = { ('VARCHAR(50)', 'VARCHAR(100)'): True, # Expanding varchar ('INT', 'BIGINT'): True, # Expanding integer ('DECIMAL(10,2)', 'DECIMAL(12,2)'): True, # Expanding decimal precision } self.risky_type_conversions = { ('VARCHAR(100)', 'VARCHAR(50)'): 'Data truncation possible', ('BIGINT', 'INT'): 'Data loss possible for large values', ('TEXT', 'VARCHAR(255)'): 'Data truncation possible' } def generate_migration(self, changes: Dict[str, List[Dict[str, Any]]]) -> MigrationPlan: """Generate complete migration plan from schema changes.""" self.migration_steps = [] self.step_counter = 0 # Generate steps in dependency order self._generate_table_creation_steps(changes['tables_added']) self._generate_column_addition_steps(changes['columns_added']) self._generate_constraint_addition_steps(changes['constraints_added']) self._generate_index_addition_steps(changes['indexes_added']) self._generate_column_modification_steps(changes['columns_modified']) self._generate_table_rename_steps(changes['tables_renamed']) self._generate_index_removal_steps(changes['indexes_dropped']) self._generate_constraint_removal_steps(changes['constraints_dropped']) self._generate_column_removal_steps(changes['columns_dropped']) self._generate_table_removal_steps(changes['tables_dropped']) # Create migration plan migration_id = self._generate_migration_id(changes) execution_order = [step.step_id for step in self.migration_steps] rollback_order = list(reversed(execution_order)) return MigrationPlan( migration_id=migration_id, created_at=datetime.now().isoformat(), source_schema_hash=self._calculate_changes_hash(changes), target_schema_hash="", # Would be calculated from target schema steps=self.migration_steps, summary=self._generate_summary(changes), execution_order=execution_order, rollback_order=rollback_order ) def _generate_step_id(self) -> str: """Generate unique step ID.""" self.step_counter += 1 return f"step_{self.step_counter:03d}" def _generate_table_creation_steps(self, tables_added: List[Dict[str, Any]]): """Generate steps for creating new tables.""" for table_info in tables_added: table = table_info['definition'] step = self._create_table_step(table) self.migration_steps.append(step) def _create_table_step(self, table: Table) -> MigrationStep: """Create migration step for table creation.""" columns_sql = [] for col_name, column in table.columns.items(): col_sql = f"{col_name} {column.data_type}" if not column.nullable: col_sql += " NOT NULL" if column.default_value: col_sql += f" DEFAULT {column.default_value}" if column.unique: col_sql += " UNIQUE" columns_sql.append(col_sql) # Add primary key if table.primary_key: pk_sql = f"PRIMARY KEY ({', '.join(table.primary_key)})" columns_sql.append(pk_sql) # Add foreign keys for col_name, ref in table.foreign_keys.items(): fk_sql = f"FOREIGN KEY ({col_name}) REFERENCES {ref}" columns_sql.append(fk_sql) create_sql = f"CREATE TABLE {table.name} (\n " + ",\n ".join(columns_sql) + "\n);" drop_sql = f"DROP TABLE IF EXISTS {table.name};" return MigrationStep( step_id=self._generate_step_id(), step_type="CREATE_TABLE", table=table.name, description=f"Create table {table.name} with {len(table.columns)} columns", sql_forward=create_sql, sql_rollback=drop_sql, validation_sql=f"SELECT COUNT(*) FROM information_schema.tables WHERE table_name = '{table.name}';", risk_level="LOW" ) def _generate_column_addition_steps(self, columns_added: List[Dict[str, Any]]): """Generate steps for adding columns.""" for col_info in columns_added: if self.zero_downtime: # For zero-downtime, add columns as nullable first step = self._add_column_zero_downtime_step(col_info) else: step = self._add_column_step(col_info) self.migration_steps.append(step) def _add_column_step(self, col_info: Dict[str, Any]) -> MigrationStep: """Create step for adding a column.""" table = col_info['table'] column = col_info['definition'] col_sql = f"{column.name} {column.data_type}" if not column.nullable: if column.default_value: col_sql += f" DEFAULT {column.default_value} NOT NULL" else: # This is risky - adding NOT NULL without default col_sql += " NOT NULL" elif column.default_value: col_sql += f" DEFAULT {column.default_value}" add_sql = f"ALTER TABLE {table} ADD COLUMN {col_sql};" drop_sql = f"ALTER TABLE {table} DROP COLUMN {column.name};" risk_level = "HIGH" if not column.nullable and not column.default_value else "LOW" return MigrationStep( step_id=self._generate_step_id(), step_type="ADD_COLUMN", table=table, description=f"Add column {column.name} to {table}", sql_forward=add_sql, sql_rollback=drop_sql, validation_sql=f"SELECT COUNT(*) FROM information_schema.columns WHERE table_name = '{table}' AND column_name = '{column.name}';", risk_level=risk_level ) def _add_column_zero_downtime_step(self, col_info: Dict[str, Any]) -> MigrationStep: """Create zero-downtime step for adding column.""" table = col_info['table'] column = col_info['definition'] # Phase 1: Add as nullable with default if needed col_sql = f"{column.name} {column.data_type}" if column.default_value: col_sql += f" DEFAULT {column.default_value}" add_sql = f"ALTER TABLE {table} ADD COLUMN {col_sql};" # If column should be NOT NULL, handle in separate phase if not column.nullable: # Add comment about needing follow-up step add_sql += f"\n-- Follow-up needed: Add NOT NULL constraint after data population" drop_sql = f"ALTER TABLE {table} DROP COLUMN {column.name};" return MigrationStep( step_id=self._generate_step_id(), step_type="ADD_COLUMN_ZD", table=table, description=f"Add column {column.name} to {table} (zero-downtime phase 1)", sql_forward=add_sql, sql_rollback=drop_sql, validation_sql=f"SELECT COUNT(*) FROM information_schema.columns WHERE table_name = '{table}' AND column_name = '{column.name}';", risk_level="LOW", zero_downtime_phase="EXPAND" ) def _generate_column_modification_steps(self, columns_modified: List[Dict[str, Any]]): """Generate steps for modifying columns.""" for col_info in columns_modified: if self.zero_downtime: steps = self._modify_column_zero_downtime_steps(col_info) self.migration_steps.extend(steps) else: step = self._modify_column_step(col_info) self.migration_steps.append(step) def _modify_column_step(self, col_info: Dict[str, Any]) -> MigrationStep: """Create step for modifying a column.""" table = col_info['table'] column = col_info['column'] current_def = col_info['current_definition'] target_def = col_info['target_definition'] changes = col_info['changes'] alter_statements = [] rollback_statements = [] # Handle different types of changes if current_def.data_type != target_def.data_type: alter_statements.append(f"ALTER COLUMN {column} TYPE {target_def.data_type}") rollback_statements.append(f"ALTER COLUMN {column} TYPE {current_def.data_type}") if current_def.nullable != target_def.nullable: if target_def.nullable: alter_statements.append(f"ALTER COLUMN {column} DROP NOT NULL") rollback_statements.append(f"ALTER COLUMN {column} SET NOT NULL") else: alter_statements.append(f"ALTER COLUMN {column} SET NOT NULL") rollback_statements.append(f"ALTER COLUMN {column} DROP NOT NULL") if current_def.default_value != target_def.default_value: if target_def.default_value: alter_statements.append(f"ALTER COLUMN {column} SET DEFAULT {target_def.default_value}") else: alter_statements.append(f"ALTER COLUMN {column} DROP DEFAULT") if current_def.default_value: rollback_statements.append(f"ALTER COLUMN {column} SET DEFAULT {current_def.default_value}") else: rollback_statements.append(f"ALTER COLUMN {column} DROP DEFAULT") # Build SQL alter_sql = f"ALTER TABLE {table}\n " + ",\n ".join(alter_statements) + ";" rollback_sql = f"ALTER TABLE {table}\n " + ",\n ".join(rollback_statements) + ";" # Assess risk risk_level = self._assess_column_modification_risk(current_def, target_def) return MigrationStep( step_id=self._generate_step_id(), step_type="MODIFY_COLUMN", table=table, description=f"Modify column {column}: {', '.join(changes)}", sql_forward=alter_sql, sql_rollback=rollback_sql, validation_sql=f"SELECT data_type, is_nullable FROM information_schema.columns WHERE table_name = '{table}' AND column_name = '{column}';", risk_level=risk_level ) def _modify_column_zero_downtime_steps(self, col_info: Dict[str, Any]) -> List[MigrationStep]: """Create zero-downtime steps for column modification.""" table = col_info['table'] column = col_info['column'] current_def = col_info['current_definition'] target_def = col_info['target_definition'] steps = [] # For zero-downtime, use expand-contract pattern temp_column = f"{column}_new" # Step 1: Add new column step1 = MigrationStep( step_id=self._generate_step_id(), step_type="ADD_TEMP_COLUMN", table=table, description=f"Add temporary column {temp_column} for zero-downtime migration", sql_forward=f"ALTER TABLE {table} ADD COLUMN {temp_column} {target_def.data_type};", sql_rollback=f"ALTER TABLE {table} DROP COLUMN {temp_column};", zero_downtime_phase="EXPAND" ) steps.append(step1) # Step 2: Copy data step2 = MigrationStep( step_id=self._generate_step_id(), step_type="COPY_COLUMN_DATA", table=table, description=f"Copy data from {column} to {temp_column}", sql_forward=f"UPDATE {table} SET {temp_column} = {column};", sql_rollback=f"UPDATE {table} SET {temp_column} = NULL;", zero_downtime_phase="EXPAND" ) steps.append(step2) # Step 3: Drop old column step3 = MigrationStep( step_id=self._generate_step_id(), step_type="DROP_OLD_COLUMN", table=table, description=f"Drop original column {column}", sql_forward=f"ALTER TABLE {table} DROP COLUMN {column};", sql_rollback=f"ALTER TABLE {table} ADD COLUMN {column} {current_def.data_type};", zero_downtime_phase="CONTRACT" ) steps.append(step3) # Step 4: Rename new column step4 = MigrationStep( step_id=self._generate_step_id(), step_type="RENAME_COLUMN", table=table, description=f"Rename {temp_column} to {column}", sql_forward=f"ALTER TABLE {table} RENAME COLUMN {temp_column} TO {column};", sql_rollback=f"ALTER TABLE {table} RENAME COLUMN {column} TO {temp_column};", zero_downtime_phase="CONTRACT" ) steps.append(step4) return steps def _assess_column_modification_risk(self, current: Column, target: Column) -> str: """Assess risk level of column modification.""" if current.data_type != target.data_type: conversion_key = (current.data_type, target.data_type) if conversion_key in self.risky_type_conversions: return "HIGH" elif conversion_key not in self.safe_type_conversions: return "MEDIUM" if current.nullable and not target.nullable: return "HIGH" # Adding NOT NULL constraint return "LOW" def _generate_constraint_addition_steps(self, constraints_added: List[Dict[str, Any]]): """Generate steps for adding constraints.""" for constraint_info in constraints_added: step = self._add_constraint_step(constraint_info) self.migration_steps.append(step) def _add_constraint_step(self, constraint_info: Dict[str, Any]) -> MigrationStep: """Create step for adding constraint.""" table = constraint_info['table'] constraint_type = constraint_info['constraint_type'] if constraint_type == 'PRIMARY_KEY': columns = constraint_info['columns'] constraint_name = f"pk_{table}" add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} PRIMARY KEY ({', '.join(columns)});" drop_sql = f"ALTER TABLE {table} DROP CONSTRAINT {constraint_name};" description = f"Add primary key on {', '.join(columns)}" elif constraint_type == 'UNIQUE': columns = constraint_info['columns'] constraint_name = f"uq_{table}_{'_'.join(columns)}" add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} UNIQUE ({', '.join(columns)});" drop_sql = f"ALTER TABLE {table} DROP CONSTRAINT {constraint_name};" description = f"Add unique constraint on {', '.join(columns)}" elif constraint_type == 'CHECK': constraint_name = constraint_info['constraint_name'] condition = constraint_info['condition'] add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} CHECK ({condition});" drop_sql = f"ALTER TABLE {table} DROP CONSTRAINT {constraint_name};" description = f"Add check constraint: {condition}" else: return None return MigrationStep( step_id=self._generate_step_id(), step_type="ADD_CONSTRAINT", table=table, description=description, sql_forward=add_sql, sql_rollback=drop_sql, risk_level="MEDIUM" # Constraints can fail if data doesn't comply ) def _generate_index_addition_steps(self, indexes_added: List[Dict[str, Any]]): """Generate steps for adding indexes.""" for index_info in indexes_added: step = self._add_index_step(index_info) self.migration_steps.append(step) def _add_index_step(self, index_info: Dict[str, Any]) -> MigrationStep: """Create step for adding index.""" table = index_info['table'] index = index_info['index'] unique_keyword = "UNIQUE " if index.get('unique', False) else "" columns_sql = ', '.join(index['columns']) create_sql = f"CREATE {unique_keyword}INDEX {index['name']} ON {table} ({columns_sql});" drop_sql = f"DROP INDEX {index['name']};" return MigrationStep( step_id=self._generate_step_id(), step_type="ADD_INDEX", table=table, description=f"Create index {index['name']} on ({columns_sql})", sql_forward=create_sql, sql_rollback=drop_sql, estimated_time="1-5 minutes depending on table size", risk_level="LOW" ) def _generate_table_rename_steps(self, tables_renamed: List[Dict[str, Any]]): """Generate steps for renaming tables.""" for rename_info in tables_renamed: step = self._rename_table_step(rename_info) self.migration_steps.append(step) def _rename_table_step(self, rename_info: Dict[str, Any]) -> MigrationStep: """Create step for renaming table.""" old_name = rename_info['old_name'] new_name = rename_info['new_name'] rename_sql = f"ALTER TABLE {old_name} RENAME TO {new_name};" rollback_sql = f"ALTER TABLE {new_name} RENAME TO {old_name};" return MigrationStep( step_id=self._generate_step_id(), step_type="RENAME_TABLE", table=old_name, description=f"Rename table {old_name} to {new_name}", sql_forward=rename_sql, sql_rollback=rollback_sql, validation_sql=f"SELECT COUNT(*) FROM information_schema.tables WHERE table_name = '{new_name}';", risk_level="LOW" ) def _generate_column_removal_steps(self, columns_dropped: List[Dict[str, Any]]): """Generate steps for removing columns.""" for col_info in columns_dropped: step = self._drop_column_step(col_info) self.migration_steps.append(step) def _drop_column_step(self, col_info: Dict[str, Any]) -> MigrationStep: """Create step for dropping column.""" table = col_info['table'] column = col_info['definition'] drop_sql = f"ALTER TABLE {table} DROP COLUMN {column.name};" # Recreate column for rollback col_sql = f"{column.name} {column.data_type}" if not column.nullable: col_sql += " NOT NULL" if column.default_value: col_sql += f" DEFAULT {column.default_value}" add_sql = f"ALTER TABLE {table} ADD COLUMN {col_sql};" return MigrationStep( step_id=self._generate_step_id(), step_type="DROP_COLUMN", table=table, description=f"Drop column {column.name} from {table}", sql_forward=drop_sql, sql_rollback=add_sql, risk_level="HIGH" # Data loss risk ) def _generate_constraint_removal_steps(self, constraints_dropped: List[Dict[str, Any]]): """Generate steps for removing constraints.""" for constraint_info in constraints_dropped: step = self._drop_constraint_step(constraint_info) if step: self.migration_steps.append(step) def _drop_constraint_step(self, constraint_info: Dict[str, Any]) -> Optional[MigrationStep]: """Create step for dropping constraint.""" table = constraint_info['table'] constraint_type = constraint_info['constraint_type'] if constraint_type == 'PRIMARY_KEY': constraint_name = f"pk_{table}" drop_sql = f"ALTER TABLE {table} DROP CONSTRAINT {constraint_name};" columns = constraint_info['columns'] add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} PRIMARY KEY ({', '.join(columns)});" description = f"Drop primary key constraint" elif constraint_type == 'UNIQUE': columns = constraint_info['columns'] constraint_name = f"uq_{table}_{'_'.join(columns)}" drop_sql = f"ALTER TABLE {table} DROP CONSTRAINT {constraint_name};" add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} UNIQUE ({', '.join(columns)});" description = f"Drop unique constraint on {', '.join(columns)}" elif constraint_type == 'CHECK': constraint_name = constraint_info['constraint_name'] condition = constraint_info.get('condition', '') drop_sql = f"ALTER TABLE {table} DROP CONSTRAINT {constraint_name};" add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} CHECK ({condition});" description = f"Drop check constraint {constraint_name}" else: return None return MigrationStep( step_id=self._generate_step_id(), step_type="DROP_CONSTRAINT", table=table, description=description, sql_forward=drop_sql, sql_rollback=add_sql, risk_level="MEDIUM" ) def _generate_index_removal_steps(self, indexes_dropped: List[Dict[str, Any]]): """Generate steps for removing indexes.""" for index_info in indexes_dropped: step = self._drop_index_step(index_info) self.migration_steps.append(step) def _drop_index_step(self, index_info: Dict[str, Any]) -> MigrationStep: """Create step for dropping index.""" table = index_info['table'] index = index_info['index'] drop_sql = f"DROP INDEX {index['name']};" # Recreate for rollback unique_keyword = "UNIQUE " if index.get('unique', False) else "" columns_sql = ', '.join(index['columns']) create_sql = f"CREATE {unique_keyword}INDEX {index['name']} ON {table} ({columns_sql});" return MigrationStep( step_id=self._generate_step_id(), step_type="DROP_INDEX", table=table, description=f"Drop index {index['name']}", sql_forward=drop_sql, sql_rollback=create_sql, risk_level="LOW" ) def _generate_table_removal_steps(self, tables_dropped: List[Dict[str, Any]]): """Generate steps for removing tables.""" for table_info in tables_dropped: step = self._drop_table_step(table_info) self.migration_steps.append(step) def _drop_table_step(self, table_info: Dict[str, Any]) -> MigrationStep: """Create step for dropping table.""" table = table_info['definition'] drop_sql = f"DROP TABLE {table.name};" # Would need to recreate entire table for rollback # This is simplified - full implementation would generate CREATE TABLE statement create_sql = f"-- Recreate table {table.name} (implementation needed)" return MigrationStep( step_id=self._generate_step_id(), step_type="DROP_TABLE", table=table.name, description=f"Drop table {table.name}", sql_forward=drop_sql, sql_rollback=create_sql, risk_level="HIGH" # Data loss risk ) def _generate_migration_id(self, changes: Dict[str, List[Dict[str, Any]]]) -> str: """Generate unique migration ID.""" content = json.dumps(changes, sort_keys=True) return hashlib.md5(content.encode()).hexdigest()[:8] def _calculate_changes_hash(self, changes: Dict[str, List[Dict[str, Any]]]) -> str: """Calculate hash of changes for versioning.""" content = json.dumps(changes, sort_keys=True) return hashlib.md5(content.encode()).hexdigest() def _generate_summary(self, changes: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]: """Generate migration summary.""" summary = { "total_steps": len(self.migration_steps), "changes_summary": { "tables_added": len(changes['tables_added']), "tables_dropped": len(changes['tables_dropped']), "tables_renamed": len(changes['tables_renamed']), "columns_added": len(changes['columns_added']), "columns_dropped": len(changes['columns_dropped']), "columns_modified": len(changes['columns_modified']), "constraints_added": len(changes['constraints_added']), "constraints_dropped": len(changes['constraints_dropped']), "indexes_added": len(changes['indexes_added']), "indexes_dropped": len(changes['indexes_dropped']) }, "risk_assessment": { "high_risk_steps": len([s for s in self.migration_steps if s.risk_level == "HIGH"]), "medium_risk_steps": len([s for s in self.migration_steps if s.risk_level == "MEDIUM"]), "low_risk_steps": len([s for s in self.migration_steps if s.risk_level == "LOW"]) }, "zero_downtime": self.zero_downtime } return summary class ValidationGenerator: """Generates validation queries for migration verification.""" def generate_validations(self, migration_plan: MigrationPlan) -> List[ValidationCheck]: """Generate validation checks for migration plan.""" validations = [] for step in migration_plan.steps: if step.step_type == "CREATE_TABLE": validations.append(self._create_table_validation(step)) elif step.step_type == "ADD_COLUMN": validations.append(self._add_column_validation(step)) elif step.step_type == "MODIFY_COLUMN": validations.append(self._modify_column_validation(step)) elif step.step_type == "ADD_INDEX": validations.append(self._add_index_validation(step)) return validations def _create_table_validation(self, step: MigrationStep) -> ValidationCheck: """Create validation for table creation.""" return ValidationCheck( check_id=f"validate_{step.step_id}", check_type="TABLE_EXISTS", table=step.table, description=f"Verify table {step.table} exists", sql_query=f"SELECT COUNT(*) FROM information_schema.tables WHERE table_name = '{step.table}';", expected_result=1 ) def _add_column_validation(self, step: MigrationStep) -> ValidationCheck: """Create validation for column addition.""" # Extract column name from SQL column_match = re.search(r'ADD COLUMN (\w+)', step.sql_forward) column_name = column_match.group(1) if column_match else "unknown" return ValidationCheck( check_id=f"validate_{step.step_id}", check_type="COLUMN_EXISTS", table=step.table, description=f"Verify column {column_name} exists in {step.table}", sql_query=f"SELECT COUNT(*) FROM information_schema.columns WHERE table_name = '{step.table}' AND column_name = '{column_name}';", expected_result=1 ) def _modify_column_validation(self, step: MigrationStep) -> ValidationCheck: """Create validation for column modification.""" return ValidationCheck( check_id=f"validate_{step.step_id}", check_type="COLUMN_MODIFIED", table=step.table, description=f"Verify column modification in {step.table}", sql_query=step.validation_sql or f"SELECT 1;", # Use provided validation or default expected_result=1 ) def _add_index_validation(self, step: MigrationStep) -> ValidationCheck: """Create validation for index addition.""" # Extract index name from SQL index_match = re.search(r'INDEX (\w+)', step.sql_forward) index_name = index_match.group(1) if index_match else "unknown" return ValidationCheck( check_id=f"validate_{step.step_id}", check_type="INDEX_EXISTS", table=step.table, description=f"Verify index {index_name} exists", sql_query=f"SELECT COUNT(*) FROM information_schema.statistics WHERE index_name = '{index_name}';", expected_result=1 ) def format_migration_plan_text(plan: MigrationPlan, validations: List[ValidationCheck] = None) -> str: """Format migration plan as human-readable text.""" lines = [] lines.append("DATABASE MIGRATION PLAN") lines.append("=" * 50) lines.append(f"Migration ID: {plan.migration_id}") lines.append(f"Created: {plan.created_at}") lines.append(f"Zero Downtime: {plan.summary['zero_downtime']}") lines.append("") # Summary summary = plan.summary lines.append("MIGRATION SUMMARY") lines.append("-" * 17) lines.append(f"Total Steps: {summary['total_steps']}") changes = summary['changes_summary'] for change_type, count in changes.items(): if count > 0: lines.append(f"{change_type.replace('_', ' ').title()}: {count}") lines.append("") # Risk Assessment risk = summary['risk_assessment'] lines.append("RISK ASSESSMENT") lines.append("-" * 15) lines.append(f"High Risk Steps: {risk['high_risk_steps']}") lines.append(f"Medium Risk Steps: {risk['medium_risk_steps']}") lines.append(f"Low Risk Steps: {risk['low_risk_steps']}") lines.append("") # Migration Steps lines.append("MIGRATION STEPS") lines.append("-" * 15) for i, step in enumerate(plan.steps, 1): lines.append(f"{i}. {step.description} ({step.risk_level} risk)") lines.append(f" Type: {step.step_type}") if step.zero_downtime_phase: lines.append(f" Phase: {step.zero_downtime_phase}") lines.append(f" Forward SQL: {step.sql_forward}") lines.append(f" Rollback SQL: {step.sql_rollback}") if step.estimated_time: lines.append(f" Estimated Time: {step.estimated_time}") lines.append("") # Validation Checks if validations: lines.append("VALIDATION CHECKS") lines.append("-" * 17) for validation in validations: lines.append(f"• {validation.description}") lines.append(f" SQL: {validation.sql_query}") lines.append(f" Expected: {validation.expected_result}") lines.append("") return "\n".join(lines) def main(): parser = argparse.ArgumentParser(description="Generate database migration scripts") parser.add_argument("--current", "-c", required=True, help="Current schema JSON file") parser.add_argument("--target", "-t", required=True, help="Target schema JSON file") parser.add_argument("--output", "-o", help="Output file (default: stdout)") parser.add_argument("--format", "-f", choices=["json", "text", "sql"], default="text", help="Output format") parser.add_argument("--zero-downtime", "-z", action="store_true", help="Generate zero-downtime migration strategy") parser.add_argument("--validate-only", "-v", action="store_true", help="Only generate validation queries") parser.add_argument("--include-validations", action="store_true", help="Include validation queries in output") args = parser.parse_args() try: # Load schemas with open(args.current, 'r') as f: current_schema = json.load(f) with open(args.target, 'r') as f: target_schema = json.load(f) # Compare schemas comparator = SchemaComparator() comparator.load_schemas(current_schema, target_schema) changes = comparator.compare_schemas() if not any(changes.values()): print("No schema changes detected.") return 0 # Generate migration generator = MigrationGenerator(zero_downtime=args.zero_downtime) migration_plan = generator.generate_migration(changes) # Generate validations if requested validations = None if args.include_validations or args.validate_only: validator = ValidationGenerator() validations = validator.generate_validations(migration_plan) # Format output if args.validate_only: output = json.dumps([asdict(v) for v in validations], indent=2) elif args.format == "json": result = {"migration_plan": asdict(migration_plan)} if validations: result["validations"] = [asdict(v) for v in validations] output = json.dumps(result, indent=2) elif args.format == "sql": sql_lines = [] sql_lines.append("-- Database Migration Script") sql_lines.append(f"-- Migration ID: {migration_plan.migration_id}") sql_lines.append(f"-- Created: {migration_plan.created_at}") sql_lines.append("") for step in migration_plan.steps: sql_lines.append(f"-- Step: {step.description}") sql_lines.append(step.sql_forward) sql_lines.append("") output = "\n".join(sql_lines) else: # text format output = format_migration_plan_text(migration_plan, validations) # Write output if args.output: with open(args.output, 'w') as f: f.write(output) else: print(output) return 0 except Exception as e: print(f"Error: {e}", file=sys.stderr) return 1 if __name__ == "__main__": sys.exit(main())