Files
CleanArchitecture-template/.brain/.agent/skills/engineering-advanced-skills/database-designer/migration_generator.py
2026-03-12 15:17:52 +07:00

1199 lines
49 KiB
Python

#!/usr/bin/env python3
"""
Database Migration Generator
Generates safe migration scripts between schema versions:
- Compares current and target schemas
- Generates ALTER TABLE statements for schema changes
- Implements zero-downtime migration strategies (expand-contract pattern)
- Creates rollback scripts for all changes
- Generates validation queries to verify migrations
- Handles complex changes like table splits/merges
Input: Current schema JSON + Target schema JSON
Output: Migration SQL + Rollback SQL + Validation queries + Execution plan
Usage:
python migration_generator.py --current current_schema.json --target target_schema.json --output migration.sql
python migration_generator.py --current current.json --target target.json --format json
python migration_generator.py --current current.json --target target.json --zero-downtime
python migration_generator.py --current current.json --target target.json --validate-only
"""
import argparse
import json
import re
import sys
from collections import defaultdict, OrderedDict
from typing import Dict, List, Set, Tuple, Optional, Any, Union
from dataclasses import dataclass, asdict
from datetime import datetime
import hashlib
@dataclass
class Column:
    """A single column of a table, as read from the schema JSON.

    Only ``name`` and ``data_type`` are required; every other attribute
    falls back to the default shown below.
    """

    name: str                               # column identifier
    data_type: str                          # SQL type text, e.g. "VARCHAR(255)"
    nullable: bool = True                   # False means the column is NOT NULL
    primary_key: bool = False               # True when part of the table's primary key
    unique: bool = False                    # column-level UNIQUE flag
    foreign_key: Optional[str] = None       # referenced target, if any
    default_value: Optional[str] = None     # DEFAULT expression, if any
    check_constraint: Optional[str] = None  # column-level CHECK condition, if any
@dataclass
class Table:
    """A table definition: its columns plus table-level constraints and indexes."""

    name: str                            # table identifier
    columns: Dict[str, Column]           # column name -> Column definition
    primary_key: List[str]               # names of the primary-key columns
    foreign_keys: Dict[str, str]         # column -> referenced_table.column
    unique_constraints: List[List[str]]  # each entry is one UNIQUE column group
    check_constraints: Dict[str, str]    # constraint name -> CHECK condition
    indexes: List[Dict[str, Any]]        # index dicts as given in the schema JSON
@dataclass
class MigrationStep:
    """One executable unit of a migration with its forward and rollback SQL."""

    step_id: str                                # unique id within the plan
    step_type: str                              # kind of step, e.g. "CREATE_TABLE"
    table: str                                  # table the step operates on
    description: str                            # human-readable summary
    sql_forward: str                            # SQL that applies the change
    sql_rollback: str                           # SQL that reverts the change
    validation_sql: Optional[str] = None        # optional query to verify the step
    # The default is None rather than a list, so the annotation must be Optional.
    dependencies: Optional[List[str]] = None
    risk_level: str = "LOW"                     # LOW, MEDIUM, HIGH
    estimated_time: Optional[str] = None        # free-form duration hint
    zero_downtime_phase: Optional[str] = None   # EXPAND, CONTRACT, or None
@dataclass
class MigrationPlan:
    """A complete migration: ordered steps plus identifying metadata."""

    migration_id: str            # short identifier of this plan
    created_at: str              # creation timestamp as text
    source_schema_hash: str      # hash recorded for the starting point
    target_schema_hash: str      # hash recorded for the target
    steps: List[MigrationStep]   # every step of the migration
    summary: Dict[str, Any]      # aggregate counts and risk overview
    execution_order: List[str]   # step ids in forward order
    rollback_order: List[str]    # step ids in rollback order
@dataclass
class ValidationCheck:
    """A query to run after migrating, together with the result it should give."""

    check_id: str          # unique id of the check
    check_type: str        # kind of check, e.g. "TABLE_EXISTS"
    table: str             # table the check refers to
    description: str       # human-readable summary
    sql_query: str         # query to execute
    expected_result: Any   # value the query is expected to return
    critical: bool = True  # defaults to True
class SchemaComparator:
    """Compares two schema versions and identifies differences.

    Typical use is ``load_schemas(current, target)`` followed by
    ``compare_schemas()``, which returns the populated ``changes`` mapping.
    Differences are emitted in sorted order so that repeated runs over the
    same inputs produce the same change lists.
    """

    def __init__(self):
        self.current_schema: Dict[str, Table] = {}
        self.target_schema: Dict[str, Table] = {}
        self.changes: Dict[str, List[Dict[str, Any]]] = self._empty_changes()

    @staticmethod
    def _empty_changes() -> Dict[str, List[Dict[str, Any]]]:
        """Return a fresh mapping with one empty list per change category."""
        return {
            'tables_added': [],
            'tables_dropped': [],
            'tables_renamed': [],
            'columns_added': [],
            'columns_dropped': [],
            'columns_modified': [],
            'columns_renamed': [],
            'constraints_added': [],
            'constraints_dropped': [],
            'indexes_added': [],
            'indexes_dropped': []
        }

    def load_schemas(self, current_data: Dict[str, Any], target_data: Dict[str, Any]):
        """Load current and target schemas from their parsed JSON form."""
        self.current_schema = self._parse_schema(current_data)
        self.target_schema = self._parse_schema(target_data)

    def _parse_schema(self, schema_data: Dict[str, Any]) -> Dict[str, Table]:
        """Parse schema JSON into Table objects.

        Returns an empty mapping when ``schema_data`` has no ``'tables'``
        entry (or the entry is empty/null).
        """
        tables: Dict[str, Table] = {}
        for table_name, table_def in (schema_data.get('tables') or {}).items():
            columns = {}
            primary_key = table_def.get('primary_key', [])
            foreign_keys = {}
            # Parse columns
            for col_name, col_def in table_def.get('columns', {}).items():
                column = Column(
                    name=col_name,
                    data_type=col_def.get('type', 'VARCHAR(255)'),
                    nullable=col_def.get('nullable', True),
                    # Primary-key membership comes from the table-level list.
                    primary_key=col_name in primary_key,
                    unique=col_def.get('unique', False),
                    foreign_key=col_def.get('foreign_key'),
                    default_value=col_def.get('default'),
                    check_constraint=col_def.get('check_constraint')
                )
                columns[col_name] = column
                if column.foreign_key:
                    foreign_keys[col_name] = column.foreign_key
            tables[table_name] = Table(
                name=table_name,
                columns=columns,
                primary_key=primary_key,
                foreign_keys=foreign_keys,
                unique_constraints=table_def.get('unique_constraints', []),
                check_constraints=table_def.get('check_constraints', {}),
                indexes=table_def.get('indexes', [])
            )
        return tables

    def compare_schemas(self) -> Dict[str, List[Dict[str, Any]]]:
        """Compare schemas and identify all changes.

        The change mapping is rebuilt on every call, so calling this method
        more than once does not accumulate duplicate entries.
        """
        self.changes = self._empty_changes()
        self._compare_tables()
        self._compare_columns()
        self._compare_constraints()
        self._compare_indexes()
        return self.changes

    def _compare_tables(self):
        """Record tables that were added, dropped, or (heuristically) renamed."""
        current_tables = set(self.current_schema.keys())
        target_tables = set(self.target_schema.keys())
        added = target_tables - current_tables
        dropped = current_tables - target_tables
        # Tables added
        for table_name in sorted(added):
            self.changes['tables_added'].append({
                'table': table_name,
                'definition': self.target_schema[table_name]
            })
        # Tables dropped
        for table_name in sorted(dropped):
            self.changes['tables_dropped'].append({
                'table': table_name,
                'definition': self.current_schema[table_name]
            })
        # Tables renamed (heuristic based on column similarity)
        self._detect_renamed_tables(dropped, added)

    def _detect_renamed_tables(self, dropped_tables: Set[str], added_tables: Set[str]):
        """Detect renamed tables based on column similarity.

        Dropped/added pairs whose similarity exceeds 0.7 are matched
        greedily, best score first; matched tables are moved out of the
        added/dropped lists and into ``tables_renamed``.
        """
        if not dropped_tables or not added_tables:
            return
        # Calculate similarity scores
        candidates = []
        for dropped_table in sorted(dropped_tables):
            for added_table in sorted(added_tables):
                score = self._calculate_table_similarity(dropped_table, added_table)
                if score > 0.7:  # High similarity threshold
                    candidates.append((score, dropped_table, added_table))
        # Best matches first; each table takes part in at most one rename.
        candidates.sort(reverse=True)
        renamed_from: Set[str] = set()
        renamed_to: Set[str] = set()
        for score, old_name, new_name in candidates:
            if old_name in renamed_from or new_name in renamed_to:
                continue
            self.changes['tables_renamed'].append({
                'old_name': old_name,
                'new_name': new_name,
                'similarity_score': score
            })
            renamed_from.add(old_name)
            renamed_to.add(new_name)
        if renamed_from:
            # A renamed table is neither an addition nor a drop.
            self.changes['tables_added'] = [
                t for t in self.changes['tables_added'] if t['table'] not in renamed_to
            ]
            self.changes['tables_dropped'] = [
                t for t in self.changes['tables_dropped'] if t['table'] not in renamed_from
            ]

    def _calculate_table_similarity(self, table1_name: str, table2_name: str) -> float:
        """Return the Jaccard similarity of the two tables' column-name sets.

        ``table1_name`` is looked up in the current schema and
        ``table2_name`` in the target schema. Two tables without columns
        score 1.0; exactly one table without columns scores 0.0.
        """
        cols1 = set(self.current_schema[table1_name].columns.keys())
        cols2 = set(self.target_schema[table2_name].columns.keys())
        if not cols1 and not cols2:
            return 1.0
        if not cols1 or not cols2:
            return 0.0
        return len(cols1 & cols2) / len(cols1 | cols2)

    def _compare_columns(self):
        """Record added, dropped and modified columns of tables present in both schemas."""
        common_tables = set(self.current_schema.keys()).intersection(self.target_schema.keys())
        for table_name in sorted(common_tables):
            current_table = self.current_schema[table_name]
            target_table = self.target_schema[table_name]
            current_columns = set(current_table.columns.keys())
            target_columns = set(target_table.columns.keys())
            # Columns added
            for col_name in sorted(target_columns - current_columns):
                self.changes['columns_added'].append({
                    'table': table_name,
                    'column': col_name,
                    'definition': target_table.columns[col_name]
                })
            # Columns dropped
            for col_name in sorted(current_columns - target_columns):
                self.changes['columns_dropped'].append({
                    'table': table_name,
                    'column': col_name,
                    'definition': current_table.columns[col_name]
                })
            # Columns modified
            for col_name in sorted(current_columns & target_columns):
                current_col = current_table.columns[col_name]
                target_col = target_table.columns[col_name]
                if self._columns_different(current_col, target_col):
                    self.changes['columns_modified'].append({
                        'table': table_name,
                        'column': col_name,
                        'current_definition': current_col,
                        'target_definition': target_col,
                        'changes': self._describe_column_changes(current_col, target_col)
                    })

    def _columns_different(self, col1: Column, col2: Column) -> bool:
        """Check if two columns have different definitions.

        The ``primary_key`` flag is not compared here; primary-key changes
        are reported by the constraint comparison.
        """
        return (col1.data_type != col2.data_type or
                col1.nullable != col2.nullable or
                col1.default_value != col2.default_value or
                col1.unique != col2.unique or
                col1.foreign_key != col2.foreign_key or
                col1.check_constraint != col2.check_constraint)

    def _describe_column_changes(self, current_col: Column, target_col: Column) -> List[str]:
        """Describe specific changes between column definitions.

        Covers every attribute that ``_columns_different`` compares, so a
        column reported as modified always has at least one description.
        """
        changes = []
        if current_col.data_type != target_col.data_type:
            changes.append(f"type: {current_col.data_type} -> {target_col.data_type}")
        if current_col.nullable != target_col.nullable:
            changes.append(f"nullable: {current_col.nullable} -> {target_col.nullable}")
        if current_col.default_value != target_col.default_value:
            changes.append(f"default: {current_col.default_value} -> {target_col.default_value}")
        if current_col.unique != target_col.unique:
            changes.append(f"unique: {current_col.unique} -> {target_col.unique}")
        if current_col.foreign_key != target_col.foreign_key:
            changes.append(f"foreign_key: {current_col.foreign_key} -> {target_col.foreign_key}")
        if current_col.check_constraint != target_col.check_constraint:
            changes.append(
                f"check_constraint: {current_col.check_constraint} -> {target_col.check_constraint}"
            )
        return changes

    def _compare_constraints(self):
        """Record primary-key, unique and check constraint changes on common tables."""
        common_tables = set(self.current_schema.keys()).intersection(self.target_schema.keys())
        for table_name in sorted(common_tables):
            current_table = self.current_schema[table_name]
            target_table = self.target_schema[table_name]
            # Compare primary keys: a changed key is a drop of the old plus an add of the new.
            if current_table.primary_key != target_table.primary_key:
                if current_table.primary_key:
                    self.changes['constraints_dropped'].append({
                        'table': table_name,
                        'constraint_type': 'PRIMARY_KEY',
                        'columns': current_table.primary_key
                    })
                if target_table.primary_key:
                    self.changes['constraints_added'].append({
                        'table': table_name,
                        'constraint_type': 'PRIMARY_KEY',
                        'columns': target_table.primary_key
                    })
            # Compare unique constraints (column lists become tuples so they can live in sets)
            current_unique = set(tuple(uc) for uc in current_table.unique_constraints)
            target_unique = set(tuple(uc) for uc in target_table.unique_constraints)
            for constraint in sorted(target_unique - current_unique):
                self.changes['constraints_added'].append({
                    'table': table_name,
                    'constraint_type': 'UNIQUE',
                    'columns': list(constraint)
                })
            for constraint in sorted(current_unique - target_unique):
                self.changes['constraints_dropped'].append({
                    'table': table_name,
                    'constraint_type': 'UNIQUE',
                    'columns': list(constraint)
                })
            # Compare check constraints as (name, condition) pairs
            current_checks = set(current_table.check_constraints.items())
            target_checks = set(target_table.check_constraints.items())
            for name, condition in sorted(target_checks - current_checks):
                self.changes['constraints_added'].append({
                    'table': table_name,
                    'constraint_type': 'CHECK',
                    'constraint_name': name,
                    'condition': condition
                })
            for name, condition in sorted(current_checks - target_checks):
                self.changes['constraints_dropped'].append({
                    'table': table_name,
                    'constraint_type': 'CHECK',
                    'constraint_name': name,
                    'condition': condition
                })

    def _compare_indexes(self):
        """Record indexes added or dropped on common tables.

        Indexes are matched by their ``'name'`` entry only; an index whose
        definition changed under the same name is not reported.
        """
        common_tables = set(self.current_schema.keys()).intersection(self.target_schema.keys())
        for table_name in sorted(common_tables):
            current_indexes = {idx['name']: idx for idx in self.current_schema[table_name].indexes}
            target_indexes = {idx['name']: idx for idx in self.target_schema[table_name].indexes}
            current_names = set(current_indexes.keys())
            target_names = set(target_indexes.keys())
            # Indexes added
            for idx_name in sorted(target_names - current_names):
                self.changes['indexes_added'].append({
                    'table': table_name,
                    'index': target_indexes[idx_name]
                })
            # Indexes dropped
            for idx_name in sorted(current_names - target_names):
                self.changes['indexes_dropped'].append({
                    'table': table_name,
                    'index': current_indexes[idx_name]
                })
class MigrationGenerator:
    """Generates migration steps from schema differences.

    ``generate_migration`` turns the change mapping produced by the schema
    comparison into an ordered list of ``MigrationStep`` objects, each with
    forward and rollback SQL, wrapped in a ``MigrationPlan``.
    """

    def __init__(self, zero_downtime: bool = False):
        self.zero_downtime = zero_downtime
        self.migration_steps: List[MigrationStep] = []
        self.step_counter = 0
        # Data type conversion safety
        self.safe_type_conversions = {
            ('VARCHAR(50)', 'VARCHAR(100)'): True,  # Expanding varchar
            ('INT', 'BIGINT'): True,  # Expanding integer
            ('DECIMAL(10,2)', 'DECIMAL(12,2)'): True,  # Expanding decimal precision
        }
        self.risky_type_conversions = {
            ('VARCHAR(100)', 'VARCHAR(50)'): 'Data truncation possible',
            ('BIGINT', 'INT'): 'Data loss possible for large values',
            ('TEXT', 'VARCHAR(255)'): 'Data truncation possible'
        }

    # ------------------------------------------------------------------
    # Shared SQL helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _has_default(column: Column) -> bool:
        """Return True when the column carries a usable DEFAULT value.

        Only ``None`` and the empty string count as "no default", so falsy
        but meaningful defaults such as ``0`` or ``False`` are kept.
        """
        return column.default_value is not None and column.default_value != ""

    @staticmethod
    def _format_reference(ref: str) -> str:
        """Render a foreign-key target for a REFERENCES clause.

        Table.foreign_keys documents targets as ``referenced_table.column``;
        SQL expects ``referenced_table(column)``. Targets that already
        contain parentheses, or no dot, are passed through unchanged.
        """
        if "(" in ref or "." not in ref:
            return ref
        ref_table, _, ref_column = ref.rpartition(".")
        return f"{ref_table}({ref_column})"

    @staticmethod
    def _create_index_sql(table: str, index: Dict[str, Any]) -> str:
        """Build the CREATE INDEX statement for one index dict."""
        unique_keyword = "UNIQUE " if index.get('unique', False) else ""
        columns_sql = ', '.join(index['columns'])
        return f"CREATE {unique_keyword}INDEX {index['name']} ON {table} ({columns_sql});"

    def _build_create_table_sql(self, table: Table) -> str:
        """Build CREATE TABLE (plus CREATE INDEX statements) for ``table``.

        Table-level unique/check constraints, column checks and indexes are
        included because the comparison emits no separate changes for the
        constraints or indexes of a newly added table.
        """
        definitions = []
        for col_name, column in table.columns.items():
            col_sql = f"{col_name} {column.data_type}"
            if not column.nullable:
                col_sql += " NOT NULL"
            if self._has_default(column):
                col_sql += f" DEFAULT {column.default_value}"
            if column.unique:
                col_sql += " UNIQUE"
            if column.check_constraint:
                col_sql += f" CHECK ({column.check_constraint})"
            definitions.append(col_sql)
        # Add primary key
        if table.primary_key:
            definitions.append(f"PRIMARY KEY ({', '.join(table.primary_key)})")
        # Add foreign keys
        for col_name, ref in table.foreign_keys.items():
            definitions.append(
                f"FOREIGN KEY ({col_name}) REFERENCES {self._format_reference(ref)}"
            )
        # Same naming scheme as _add_constraint_step so later drops find the constraint.
        for columns in table.unique_constraints:
            constraint_name = f"uq_{table.name}_{'_'.join(columns)}"
            definitions.append(f"CONSTRAINT {constraint_name} UNIQUE ({', '.join(columns)})")
        for constraint_name, condition in table.check_constraints.items():
            definitions.append(f"CONSTRAINT {constraint_name} CHECK ({condition})")
        statements = [f"CREATE TABLE {table.name} (\n " + ",\n ".join(definitions) + "\n);"]
        statements.extend(self._create_index_sql(table.name, index) for index in table.indexes)
        return "\n".join(statements)

    # ------------------------------------------------------------------
    # Plan assembly
    # ------------------------------------------------------------------
    def generate_migration(self, changes: Dict[str, List[Dict[str, Any]]]) -> MigrationPlan:
        """Generate complete migration plan from schema changes.

        Additive steps are generated first and destructive ones last; the
        rollback order is the exact reverse of the execution order.
        """
        self.migration_steps = []
        self.step_counter = 0
        # Generate steps in dependency order
        self._generate_table_creation_steps(changes['tables_added'])
        self._generate_column_addition_steps(changes['columns_added'])
        self._generate_constraint_addition_steps(changes['constraints_added'])
        self._generate_index_addition_steps(changes['indexes_added'])
        self._generate_column_modification_steps(changes['columns_modified'])
        self._generate_table_rename_steps(changes['tables_renamed'])
        self._generate_index_removal_steps(changes['indexes_dropped'])
        self._generate_constraint_removal_steps(changes['constraints_dropped'])
        self._generate_column_removal_steps(changes['columns_dropped'])
        self._generate_table_removal_steps(changes['tables_dropped'])
        # Create migration plan
        execution_order = [step.step_id for step in self.migration_steps]
        return MigrationPlan(
            migration_id=self._generate_migration_id(changes),
            created_at=datetime.now().isoformat(),
            # NOTE: this is a digest of the change set, not of the source schema itself.
            source_schema_hash=self._calculate_changes_hash(changes),
            target_schema_hash="",  # Would be calculated from target schema
            steps=self.migration_steps,
            summary=self._generate_summary(changes),
            execution_order=execution_order,
            rollback_order=list(reversed(execution_order))
        )

    def _generate_step_id(self) -> str:
        """Generate unique step ID of the form ``step_NNN``."""
        self.step_counter += 1
        return f"step_{self.step_counter:03d}"

    # ------------------------------------------------------------------
    # Tables
    # ------------------------------------------------------------------
    def _generate_table_creation_steps(self, tables_added: List[Dict[str, Any]]):
        """Generate steps for creating new tables."""
        for table_info in tables_added:
            self.migration_steps.append(self._create_table_step(table_info['definition']))

    def _create_table_step(self, table: Table) -> MigrationStep:
        """Create migration step for table creation; rollback drops the table."""
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="CREATE_TABLE",
            table=table.name,
            description=f"Create table {table.name} with {len(table.columns)} columns",
            sql_forward=self._build_create_table_sql(table),
            sql_rollback=f"DROP TABLE IF EXISTS {table.name};",
            validation_sql=f"SELECT COUNT(*) FROM information_schema.tables WHERE table_name = '{table.name}';",
            risk_level="LOW"
        )

    def _generate_table_rename_steps(self, tables_renamed: List[Dict[str, Any]]):
        """Generate steps for renaming tables."""
        for rename_info in tables_renamed:
            self.migration_steps.append(self._rename_table_step(rename_info))

    def _rename_table_step(self, rename_info: Dict[str, Any]) -> MigrationStep:
        """Create step for renaming table; rollback renames it back."""
        old_name = rename_info['old_name']
        new_name = rename_info['new_name']
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="RENAME_TABLE",
            table=old_name,
            description=f"Rename table {old_name} to {new_name}",
            sql_forward=f"ALTER TABLE {old_name} RENAME TO {new_name};",
            sql_rollback=f"ALTER TABLE {new_name} RENAME TO {old_name};",
            validation_sql=f"SELECT COUNT(*) FROM information_schema.tables WHERE table_name = '{new_name}';",
            risk_level="LOW"
        )

    def _generate_table_removal_steps(self, tables_dropped: List[Dict[str, Any]]):
        """Generate steps for removing tables."""
        for table_info in tables_dropped:
            self.migration_steps.append(self._drop_table_step(table_info))

    def _drop_table_step(self, table_info: Dict[str, Any]) -> MigrationStep:
        """Create step for dropping table.

        The rollback recreates the table structure from its recorded
        definition; the dropped rows themselves cannot be restored by SQL
        generated here.
        """
        table = table_info['definition']
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="DROP_TABLE",
            table=table.name,
            description=f"Drop table {table.name}",
            sql_forward=f"DROP TABLE {table.name};",
            sql_rollback=self._build_create_table_sql(table),
            risk_level="HIGH"  # Data loss risk
        )

    # ------------------------------------------------------------------
    # Columns: add / drop
    # ------------------------------------------------------------------
    def _generate_column_addition_steps(self, columns_added: List[Dict[str, Any]]):
        """Generate steps for adding columns."""
        for col_info in columns_added:
            if self.zero_downtime:
                # For zero-downtime, add columns as nullable first
                step = self._add_column_zero_downtime_step(col_info)
            else:
                step = self._add_column_step(col_info)
            self.migration_steps.append(step)

    def _add_column_step(self, col_info: Dict[str, Any]) -> MigrationStep:
        """Create step for adding a column.

        Adding a NOT NULL column without a default is rated HIGH risk.
        """
        table = col_info['table']
        column = col_info['definition']
        has_default = self._has_default(column)
        col_sql = f"{column.name} {column.data_type}"
        if has_default:
            col_sql += f" DEFAULT {column.default_value}"
        if not column.nullable:
            # Without a default this is risky - see risk_level below.
            col_sql += " NOT NULL"
        risk_level = "HIGH" if not column.nullable and not has_default else "LOW"
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="ADD_COLUMN",
            table=table,
            description=f"Add column {column.name} to {table}",
            sql_forward=f"ALTER TABLE {table} ADD COLUMN {col_sql};",
            sql_rollback=f"ALTER TABLE {table} DROP COLUMN {column.name};",
            validation_sql=f"SELECT COUNT(*) FROM information_schema.columns WHERE table_name = '{table}' AND column_name = '{column.name}';",
            risk_level=risk_level
        )

    def _add_column_zero_downtime_step(self, col_info: Dict[str, Any]) -> MigrationStep:
        """Create zero-downtime step for adding column.

        The column is always added as nullable; when the target is NOT NULL
        a trailing SQL comment flags the required follow-up.
        """
        table = col_info['table']
        column = col_info['definition']
        # Phase 1: Add as nullable with default if needed
        col_sql = f"{column.name} {column.data_type}"
        if self._has_default(column):
            col_sql += f" DEFAULT {column.default_value}"
        add_sql = f"ALTER TABLE {table} ADD COLUMN {col_sql};"
        # If column should be NOT NULL, handle in separate phase
        if not column.nullable:
            # Add comment about needing follow-up step
            add_sql += "\n-- Follow-up needed: Add NOT NULL constraint after data population"
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="ADD_COLUMN_ZD",
            table=table,
            description=f"Add column {column.name} to {table} (zero-downtime phase 1)",
            sql_forward=add_sql,
            sql_rollback=f"ALTER TABLE {table} DROP COLUMN {column.name};",
            validation_sql=f"SELECT COUNT(*) FROM information_schema.columns WHERE table_name = '{table}' AND column_name = '{column.name}';",
            risk_level="LOW",
            zero_downtime_phase="EXPAND"
        )

    def _generate_column_removal_steps(self, columns_dropped: List[Dict[str, Any]]):
        """Generate steps for removing columns."""
        for col_info in columns_dropped:
            self.migration_steps.append(self._drop_column_step(col_info))

    def _drop_column_step(self, col_info: Dict[str, Any]) -> MigrationStep:
        """Create step for dropping column; rollback re-adds the (empty) column."""
        table = col_info['table']
        column = col_info['definition']
        # Recreate column for rollback
        col_sql = f"{column.name} {column.data_type}"
        if not column.nullable:
            col_sql += " NOT NULL"
        if self._has_default(column):
            col_sql += f" DEFAULT {column.default_value}"
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="DROP_COLUMN",
            table=table,
            description=f"Drop column {column.name} from {table}",
            sql_forward=f"ALTER TABLE {table} DROP COLUMN {column.name};",
            sql_rollback=f"ALTER TABLE {table} ADD COLUMN {col_sql};",
            risk_level="HIGH"  # Data loss risk
        )

    # ------------------------------------------------------------------
    # Columns: modify
    # ------------------------------------------------------------------
    def _generate_column_modification_steps(self, columns_modified: List[Dict[str, Any]]):
        """Generate steps for modifying columns."""
        for col_info in columns_modified:
            if self.zero_downtime:
                self.migration_steps.extend(self._modify_column_zero_downtime_steps(col_info))
            else:
                self.migration_steps.append(self._modify_column_step(col_info))

    def _modify_column_step(self, col_info: Dict[str, Any]) -> MigrationStep:
        """Create step for modifying a column in place.

        Type, nullability and default changes become ALTER COLUMN clauses.
        Unique, foreign-key and check changes have no reliable constraint
        name to target, so they are flagged in a SQL comment instead of
        being silently dropped or producing an empty ALTER TABLE.
        """
        table = col_info['table']
        column = col_info['column']
        current_def = col_info['current_definition']
        target_def = col_info['target_definition']
        changes = col_info['changes']
        alter_statements = []
        rollback_statements = []
        # Handle different types of changes
        if current_def.data_type != target_def.data_type:
            alter_statements.append(f"ALTER COLUMN {column} TYPE {target_def.data_type}")
            rollback_statements.append(f"ALTER COLUMN {column} TYPE {current_def.data_type}")
        if current_def.nullable != target_def.nullable:
            if target_def.nullable:
                alter_statements.append(f"ALTER COLUMN {column} DROP NOT NULL")
                rollback_statements.append(f"ALTER COLUMN {column} SET NOT NULL")
            else:
                alter_statements.append(f"ALTER COLUMN {column} SET NOT NULL")
                rollback_statements.append(f"ALTER COLUMN {column} DROP NOT NULL")
        if current_def.default_value != target_def.default_value:
            if self._has_default(target_def):
                alter_statements.append(f"ALTER COLUMN {column} SET DEFAULT {target_def.default_value}")
            else:
                alter_statements.append(f"ALTER COLUMN {column} DROP DEFAULT")
            if self._has_default(current_def):
                rollback_statements.append(f"ALTER COLUMN {column} SET DEFAULT {current_def.default_value}")
            else:
                rollback_statements.append(f"ALTER COLUMN {column} DROP DEFAULT")
        # Changes that cannot be expressed as ALTER COLUMN clauses
        manual = []
        if current_def.unique != target_def.unique:
            manual.append(f"unique: {current_def.unique} -> {target_def.unique}")
        if current_def.foreign_key != target_def.foreign_key:
            manual.append(f"foreign_key: {current_def.foreign_key} -> {target_def.foreign_key}")
        if current_def.check_constraint != target_def.check_constraint:
            manual.append(
                f"check_constraint: {current_def.check_constraint} -> {target_def.check_constraint}"
            )
        # Build SQL
        forward_parts = []
        rollback_parts = []
        if alter_statements:
            forward_parts.append(f"ALTER TABLE {table}\n " + ",\n ".join(alter_statements) + ";")
            rollback_parts.append(f"ALTER TABLE {table}\n " + ",\n ".join(rollback_statements) + ";")
        if manual:
            details = "; ".join(manual)
            forward_parts.append(f"-- Manual change required for {table}.{column}: {details}")
            rollback_parts.append(f"-- Manual revert required for {table}.{column}: {details}")
        if not forward_parts:
            forward_parts.append(f"-- No SQL changes required for {table}.{column}")
            rollback_parts.append(f"-- No SQL changes required for {table}.{column}")
        # Assess risk
        risk_level = self._assess_column_modification_risk(current_def, target_def)
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="MODIFY_COLUMN",
            table=table,
            description=f"Modify column {column}: {', '.join(changes)}",
            sql_forward="\n".join(forward_parts),
            sql_rollback="\n".join(rollback_parts),
            validation_sql=f"SELECT data_type, is_nullable FROM information_schema.columns WHERE table_name = '{table}' AND column_name = '{column}';",
            risk_level=risk_level
        )

    def _modify_column_zero_downtime_steps(self, col_info: Dict[str, Any]) -> List[MigrationStep]:
        """Create zero-downtime steps for column modification.

        Expand-contract: add ``<column>_new`` with the target definition,
        copy the data, drop the original, then rename the new column into
        place. Rollbacks run in reverse, so the rollback of the drop step
        re-adds the original column and copies the data back from the
        ``_new`` column before that one is cleared and removed.
        """
        table = col_info['table']
        column = col_info['column']
        current_def = col_info['current_definition']
        target_def = col_info['target_definition']
        steps = []
        # For zero-downtime, use expand-contract pattern
        temp_column = f"{column}_new"
        # Step 1: Add new column (target type and default; NOT NULL follows the copy)
        temp_sql = f"{temp_column} {target_def.data_type}"
        if self._has_default(target_def):
            temp_sql += f" DEFAULT {target_def.default_value}"
        steps.append(MigrationStep(
            step_id=self._generate_step_id(),
            step_type="ADD_TEMP_COLUMN",
            table=table,
            description=f"Add temporary column {temp_column} for zero-downtime migration",
            sql_forward=f"ALTER TABLE {table} ADD COLUMN {temp_sql};",
            sql_rollback=f"ALTER TABLE {table} DROP COLUMN {temp_column};",
            zero_downtime_phase="EXPAND"
        ))
        # Step 2: Copy data, then enforce NOT NULL if the target requires it
        copy_sql = f"UPDATE {table} SET {temp_column} = {column};"
        uncopy_sql = f"UPDATE {table} SET {temp_column} = NULL;"
        if not target_def.nullable:
            copy_sql += f"\nALTER TABLE {table} ALTER COLUMN {temp_column} SET NOT NULL;"
            # The constraint must go before the values can be reset to NULL.
            uncopy_sql = f"ALTER TABLE {table} ALTER COLUMN {temp_column} DROP NOT NULL;\n" + uncopy_sql
        steps.append(MigrationStep(
            step_id=self._generate_step_id(),
            step_type="COPY_COLUMN_DATA",
            table=table,
            description=f"Copy data from {column} to {temp_column}",
            sql_forward=copy_sql,
            sql_rollback=uncopy_sql,
            # The copy is where an unsafe conversion or NULL data would fail.
            risk_level=self._assess_column_modification_risk(current_def, target_def),
            zero_downtime_phase="EXPAND"
        ))
        # Step 3: Drop old column; rollback restores both structure and data
        restore_sql = f"{column} {current_def.data_type}"
        if self._has_default(current_def):
            restore_sql += f" DEFAULT {current_def.default_value}"
        restore_statements = [
            f"ALTER TABLE {table} ADD COLUMN {restore_sql};",
            f"UPDATE {table} SET {column} = {temp_column};",
        ]
        if not current_def.nullable:
            restore_statements.append(f"ALTER TABLE {table} ALTER COLUMN {column} SET NOT NULL;")
        steps.append(MigrationStep(
            step_id=self._generate_step_id(),
            step_type="DROP_OLD_COLUMN",
            table=table,
            description=f"Drop original column {column}",
            sql_forward=f"ALTER TABLE {table} DROP COLUMN {column};",
            sql_rollback="\n".join(restore_statements),
            zero_downtime_phase="CONTRACT"
        ))
        # Step 4: Rename new column
        steps.append(MigrationStep(
            step_id=self._generate_step_id(),
            step_type="RENAME_COLUMN",
            table=table,
            description=f"Rename {temp_column} to {column}",
            sql_forward=f"ALTER TABLE {table} RENAME COLUMN {temp_column} TO {column};",
            sql_rollback=f"ALTER TABLE {table} RENAME COLUMN {column} TO {temp_column};",
            zero_downtime_phase="CONTRACT"
        ))
        return steps

    def _assess_column_modification_risk(self, current: Column, target: Column) -> str:
        """Assess risk level of column modification.

        Known risky type conversions are HIGH, type conversions not listed
        as safe are MEDIUM, and adding NOT NULL is HIGH; otherwise LOW.
        """
        if current.data_type != target.data_type:
            conversion_key = (current.data_type, target.data_type)
            if conversion_key in self.risky_type_conversions:
                return "HIGH"
            if conversion_key not in self.safe_type_conversions:
                return "MEDIUM"
        if current.nullable and not target.nullable:
            return "HIGH"  # Adding NOT NULL constraint
        return "LOW"

    # ------------------------------------------------------------------
    # Constraints
    # ------------------------------------------------------------------
    def _generate_constraint_addition_steps(self, constraints_added: List[Dict[str, Any]]):
        """Generate steps for adding constraints, skipping unsupported types."""
        for constraint_info in constraints_added:
            step = self._add_constraint_step(constraint_info)
            if step:
                self.migration_steps.append(step)

    def _add_constraint_step(self, constraint_info: Dict[str, Any]) -> Optional[MigrationStep]:
        """Create step for adding constraint.

        Returns None for constraint types other than PRIMARY_KEY, UNIQUE
        and CHECK.
        """
        table = constraint_info['table']
        constraint_type = constraint_info['constraint_type']
        if constraint_type == 'PRIMARY_KEY':
            columns = constraint_info['columns']
            constraint_name = f"pk_{table}"
            add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} PRIMARY KEY ({', '.join(columns)});"
            description = f"Add primary key on {', '.join(columns)}"
        elif constraint_type == 'UNIQUE':
            columns = constraint_info['columns']
            constraint_name = f"uq_{table}_{'_'.join(columns)}"
            add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} UNIQUE ({', '.join(columns)});"
            description = f"Add unique constraint on {', '.join(columns)}"
        elif constraint_type == 'CHECK':
            constraint_name = constraint_info['constraint_name']
            condition = constraint_info['condition']
            add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} CHECK ({condition});"
            description = f"Add check constraint: {condition}"
        else:
            return None
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="ADD_CONSTRAINT",
            table=table,
            description=description,
            sql_forward=add_sql,
            sql_rollback=f"ALTER TABLE {table} DROP CONSTRAINT {constraint_name};",
            risk_level="MEDIUM"  # Constraints can fail if data doesn't comply
        )

    def _generate_constraint_removal_steps(self, constraints_dropped: List[Dict[str, Any]]):
        """Generate steps for removing constraints, skipping unsupported types."""
        for constraint_info in constraints_dropped:
            step = self._drop_constraint_step(constraint_info)
            if step:
                self.migration_steps.append(step)

    def _drop_constraint_step(self, constraint_info: Dict[str, Any]) -> Optional[MigrationStep]:
        """Create step for dropping constraint; rollback re-adds it.

        Constraint names follow the same ``pk_``/``uq_`` scheme used when
        adding constraints. Returns None for unsupported constraint types.
        """
        table = constraint_info['table']
        constraint_type = constraint_info['constraint_type']
        if constraint_type == 'PRIMARY_KEY':
            columns = constraint_info['columns']
            constraint_name = f"pk_{table}"
            add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} PRIMARY KEY ({', '.join(columns)});"
            description = "Drop primary key constraint"
        elif constraint_type == 'UNIQUE':
            columns = constraint_info['columns']
            constraint_name = f"uq_{table}_{'_'.join(columns)}"
            add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} UNIQUE ({', '.join(columns)});"
            description = f"Drop unique constraint on {', '.join(columns)}"
        elif constraint_type == 'CHECK':
            constraint_name = constraint_info['constraint_name']
            condition = constraint_info.get('condition', '')
            add_sql = f"ALTER TABLE {table} ADD CONSTRAINT {constraint_name} CHECK ({condition});"
            description = f"Drop check constraint {constraint_name}"
        else:
            return None
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="DROP_CONSTRAINT",
            table=table,
            description=description,
            sql_forward=f"ALTER TABLE {table} DROP CONSTRAINT {constraint_name};",
            sql_rollback=add_sql,
            risk_level="MEDIUM"
        )

    # ------------------------------------------------------------------
    # Indexes
    # ------------------------------------------------------------------
    def _generate_index_addition_steps(self, indexes_added: List[Dict[str, Any]]):
        """Generate steps for adding indexes."""
        for index_info in indexes_added:
            self.migration_steps.append(self._add_index_step(index_info))

    def _add_index_step(self, index_info: Dict[str, Any]) -> MigrationStep:
        """Create step for adding index; rollback drops it."""
        table = index_info['table']
        index = index_info['index']
        columns_sql = ', '.join(index['columns'])
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="ADD_INDEX",
            table=table,
            description=f"Create index {index['name']} on ({columns_sql})",
            sql_forward=self._create_index_sql(table, index),
            sql_rollback=f"DROP INDEX {index['name']};",
            estimated_time="1-5 minutes depending on table size",
            risk_level="LOW"
        )

    def _generate_index_removal_steps(self, indexes_dropped: List[Dict[str, Any]]):
        """Generate steps for removing indexes."""
        for index_info in indexes_dropped:
            self.migration_steps.append(self._drop_index_step(index_info))

    def _drop_index_step(self, index_info: Dict[str, Any]) -> MigrationStep:
        """Create step for dropping index; rollback recreates it."""
        table = index_info['table']
        index = index_info['index']
        return MigrationStep(
            step_id=self._generate_step_id(),
            step_type="DROP_INDEX",
            table=table,
            description=f"Drop index {index['name']}",
            sql_forward=f"DROP INDEX {index['name']};",
            sql_rollback=self._create_index_sql(table, index),
            risk_level="LOW"
        )

    # ------------------------------------------------------------------
    # Identification and summary
    # ------------------------------------------------------------------
    def _generate_migration_id(self, changes: Dict[str, List[Dict[str, Any]]]) -> str:
        """Generate an 8-character migration ID derived from the change set."""
        return self._calculate_changes_hash(changes)[:8]

    def _calculate_changes_hash(self, changes: Dict[str, List[Dict[str, Any]]]) -> str:
        """Calculate an MD5 hex digest of the changes for versioning.

        The change set embeds dataclass instances (column and table
        definitions), which ``json.dumps`` cannot encode on its own;
        ``asdict`` converts them and still raises TypeError for anything
        that is not a dataclass instance. MD5 serves only as a fingerprint.
        """
        content = json.dumps(changes, sort_keys=True, default=asdict)
        return hashlib.md5(content.encode()).hexdigest()

    def _generate_summary(self, changes: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
        """Generate migration summary: step total, per-category counts, risk counts."""
        categories = (
            'tables_added', 'tables_dropped', 'tables_renamed',
            'columns_added', 'columns_dropped', 'columns_modified',
            'constraints_added', 'constraints_dropped',
            'indexes_added', 'indexes_dropped',
        )
        risk_counts = {"HIGH": 0, "MEDIUM": 0, "LOW": 0}
        for step in self.migration_steps:
            if step.risk_level in risk_counts:
                risk_counts[step.risk_level] += 1
        return {
            "total_steps": len(self.migration_steps),
            "changes_summary": {name: len(changes[name]) for name in categories},
            "risk_assessment": {
                "high_risk_steps": risk_counts["HIGH"],
                "medium_risk_steps": risk_counts["MEDIUM"],
                "low_risk_steps": risk_counts["LOW"]
            },
            "zero_downtime": self.zero_downtime
        }
class ValidationGenerator:
    """Generates validation queries for migration verification.

    Each supported migration step type (CREATE_TABLE, ADD_COLUMN,
    MODIFY_COLUMN, ADD_INDEX) is mapped to one ``ValidationCheck`` whose
    query is expected to return 1 once the step has been applied. Steps of
    any other type produce no check.
    """

    # Identifier extraction from forward SQL. Optional modifiers
    # (IF NOT EXISTS, CONCURRENTLY) and an opening identifier quote are
    # skipped so the captured word is the actual name, not a keyword.
    _ADD_COLUMN_RE = re.compile(
        r'ADD COLUMN\s+(?:IF\s+NOT\s+EXISTS\s+)?["`]?(\w+)'
    )
    _INDEX_RE = re.compile(
        r'INDEX\s+(?:CONCURRENTLY\s+)?(?:IF\s+NOT\s+EXISTS\s+)?["`]?(\w+)'
    )

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Double embedded single quotes so *value* is safe inside '...'."""
        return str(value).replace("'", "''")

    def generate_validations(self, migration_plan: MigrationPlan) -> List[ValidationCheck]:
        """Generate validation checks for migration plan.

        Args:
            migration_plan: Plan whose ``steps`` are inspected in order.

        Returns:
            One check per step with a supported ``step_type``, in step order.
        """
        builders = {
            "CREATE_TABLE": self._create_table_validation,
            "ADD_COLUMN": self._add_column_validation,
            "MODIFY_COLUMN": self._modify_column_validation,
            "ADD_INDEX": self._add_index_validation,
        }
        return [
            builders[step.step_type](step)
            for step in migration_plan.steps
            if step.step_type in builders
        ]

    def _create_table_validation(self, step: MigrationStep) -> ValidationCheck:
        """Create validation that the table named by ``step.table`` exists."""
        table_literal = self._quote_literal(step.table)
        return ValidationCheck(
            check_id=f"validate_{step.step_id}",
            check_type="TABLE_EXISTS",
            table=step.table,
            description=f"Verify table {step.table} exists",
            sql_query=f"SELECT COUNT(*) FROM information_schema.tables WHERE table_name = '{table_literal}';",
            expected_result=1
        )

    def _add_column_validation(self, step: MigrationStep) -> ValidationCheck:
        """Create validation that the column added by ``step`` exists.

        The column name is parsed from ``step.sql_forward``; if it cannot be
        found the placeholder name "unknown" is used.
        """
        column_match = self._ADD_COLUMN_RE.search(step.sql_forward)
        column_name = column_match.group(1) if column_match else "unknown"
        table_literal = self._quote_literal(step.table)
        return ValidationCheck(
            check_id=f"validate_{step.step_id}",
            check_type="COLUMN_EXISTS",
            table=step.table,
            description=f"Verify column {column_name} exists in {step.table}",
            sql_query=f"SELECT COUNT(*) FROM information_schema.columns WHERE table_name = '{table_literal}' AND column_name = '{column_name}';",
            expected_result=1
        )

    def _modify_column_validation(self, step: MigrationStep) -> ValidationCheck:
        """Create validation for a column modification.

        Uses the step's own ``validation_sql`` when present, otherwise a
        trivial ``SELECT 1;`` query.
        """
        return ValidationCheck(
            check_id=f"validate_{step.step_id}",
            check_type="COLUMN_MODIFIED",
            table=step.table,
            description=f"Verify column modification in {step.table}",
            sql_query=step.validation_sql or "SELECT 1;",
            expected_result=1
        )

    def _add_index_validation(self, step: MigrationStep) -> ValidationCheck:
        """Create validation that the index added by ``step`` exists.

        The index name is parsed from ``step.sql_forward``; if it cannot be
        found the placeholder name "unknown" is used.
        """
        index_match = self._INDEX_RE.search(step.sql_forward)
        index_name = index_match.group(1) if index_match else "unknown"
        return ValidationCheck(
            check_id=f"validate_{step.step_id}",
            check_type="INDEX_EXISTS",
            table=step.table,
            description=f"Verify index {index_name} exists",
            sql_query=f"SELECT COUNT(*) FROM information_schema.statistics WHERE index_name = '{index_name}';",
            expected_result=1
        )
def format_migration_plan_text(plan: MigrationPlan, validations: List[ValidationCheck] = None) -> str:
    """Render a migration plan as a plain-text report.

    Sections, in order: header, change summary (only categories with a
    non-zero count), risk assessment, the numbered migration steps and,
    when ``validations`` is non-empty, the validation checks.

    Args:
        plan: Migration plan to render.
        validations: Optional validation checks to append to the report.

    Returns:
        The report as a single newline-joined string.
    """
    # Header
    report = [
        "DATABASE MIGRATION PLAN",
        "=" * 50,
        f"Migration ID: {plan.migration_id}",
        f"Created: {plan.created_at}",
        f"Zero Downtime: {plan.summary['zero_downtime']}",
        "",
    ]

    # Summary: one line per change category that actually occurred
    overview = plan.summary
    report += [
        "MIGRATION SUMMARY",
        "-" * 17,
        f"Total Steps: {overview['total_steps']}",
    ]
    report += [
        f"{category.replace('_', ' ').title()}: {amount}"
        for category, amount in overview['changes_summary'].items()
        if amount > 0
    ]
    report.append("")

    # Risk assessment
    risk_counts = overview['risk_assessment']
    report += [
        "RISK ASSESSMENT",
        "-" * 15,
        f"High Risk Steps: {risk_counts['high_risk_steps']}",
        f"Medium Risk Steps: {risk_counts['medium_risk_steps']}",
        f"Low Risk Steps: {risk_counts['low_risk_steps']}",
        "",
    ]

    # Steps: each entry is followed by a blank separator line
    report += ["MIGRATION STEPS", "-" * 15]
    for number, item in enumerate(plan.steps, 1):
        report.append(f"{number}. {item.description} ({item.risk_level} risk)")
        report.append(f" Type: {item.step_type}")
        if item.zero_downtime_phase:
            report.append(f" Phase: {item.zero_downtime_phase}")
        report.append(f" Forward SQL: {item.sql_forward}")
        report.append(f" Rollback SQL: {item.sql_rollback}")
        if item.estimated_time:
            report.append(f" Estimated Time: {item.estimated_time}")
        report.append("")

    # Validation checks (skipped when None or empty)
    if validations:
        report += ["VALIDATION CHECKS", "-" * 17]
        for check in validations:
            report += [
                f"{check.description}",
                f" SQL: {check.sql_query}",
                f" Expected: {check.expected_result}",
                "",
            ]

    return "\n".join(report)
def main():
    """Command-line entry point for the migration generator.

    Loads the current and target schema JSON files, compares them,
    generates a migration plan (optionally with validation checks) and
    writes it in the requested format to ``--output`` or stdout.

    Returns:
        0 on success (including when no schema changes are detected),
        1 if any error occurs; the error message is printed to stderr.
    """
    parser = argparse.ArgumentParser(description="Generate database migration scripts")
    parser.add_argument("--current", "-c", required=True, help="Current schema JSON file")
    parser.add_argument("--target", "-t", required=True, help="Target schema JSON file")
    parser.add_argument("--output", "-o", help="Output file (default: stdout)")
    parser.add_argument("--format", "-f", choices=["json", "text", "sql"], default="text",
                        help="Output format")
    parser.add_argument("--zero-downtime", "-z", action="store_true",
                        help="Generate zero-downtime migration strategy")
    parser.add_argument("--validate-only", "-v", action="store_true",
                        help="Only generate validation queries")
    parser.add_argument("--include-validations", action="store_true",
                        help="Include validation queries in output")
    args = parser.parse_args()

    try:
        # Load schemas. JSON text is UTF-8 by specification, so decode it
        # explicitly instead of relying on the platform default encoding.
        with open(args.current, 'r', encoding='utf-8') as f:
            current_schema = json.load(f)
        with open(args.target, 'r', encoding='utf-8') as f:
            target_schema = json.load(f)

        # Compare schemas
        comparator = SchemaComparator()
        comparator.load_schemas(current_schema, target_schema)
        changes = comparator.compare_schemas()
        if not any(changes.values()):
            print("No schema changes detected.")
            return 0

        # Generate migration
        generator = MigrationGenerator(zero_downtime=args.zero_downtime)
        migration_plan = generator.generate_migration(changes)

        # Generate validations if requested
        validations = None
        if args.include_validations or args.validate_only:
            validator = ValidationGenerator()
            validations = validator.generate_validations(migration_plan)

        # Format output; --validate-only takes precedence over --format
        if args.validate_only:
            output = json.dumps([asdict(v) for v in validations], indent=2)
        elif args.format == "json":
            result = {"migration_plan": asdict(migration_plan)}
            if validations:
                result["validations"] = [asdict(v) for v in validations]
            output = json.dumps(result, indent=2)
        elif args.format == "sql":
            sql_lines = [
                "-- Database Migration Script",
                f"-- Migration ID: {migration_plan.migration_id}",
                f"-- Created: {migration_plan.created_at}",
                "",
            ]
            for step in migration_plan.steps:
                sql_lines.append(f"-- Step: {step.description}")
                sql_lines.append(step.sql_forward)
                sql_lines.append("")
            output = "\n".join(sql_lines)
        else:  # text format
            output = format_migration_plan_text(migration_plan, validations)

        # Write output. UTF-8 keeps non-ASCII identifiers intact regardless
        # of the platform's locale encoding.
        if args.output:
            with open(args.output, 'w', encoding='utf-8') as f:
                f.write(output)
        else:
            print(output)
        return 0
    except Exception as e:
        # Top-level CLI boundary: report the failure and signal it through
        # the exit status instead of a traceback.
        print(f"Error: {e}", file=sys.stderr)
        return 1
# Run the CLI only when executed as a script; main()'s return value is
# passed to sys.exit() and becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())