asyncpg-migrate/asyncpg_migrate/asyncpg_migrate.py
Michal Szczepanski 2a4d58759b
Update asyncpg_migrate with pep8 compatible template
Signed-off-by: Michal Szczepanski <michal@vane.pl>
2020-11-10 01:34:10 +01:00

258 lines
8.2 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import time
import argparse
from typing import List, Dict, Any
import asyncio
import asyncpg
from asyncpg_migrate.migrate_config import MigrateConfig
# Skeleton written into every new migration file by add_migration().
# It must be valid, PEP 8 formatted Python on its own: each generated module
# exposes ``up(config)`` and ``down(config)`` coroutines, which the migrate and
# rollback commands await. The connection opened in each coroutine is closed in
# a ``finally`` clause so a migration never leaves a connection behind.
migration_template = """#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncpg


async def up(config):
    conn: asyncpg.Connection = await asyncpg.connect(**config)
    try:
        pass
    finally:
        await conn.close()


async def down(config):
    conn: asyncpg.Connection = await asyncpg.connect(**config)
    try:
        pass
    finally:
        await conn.close()
"""
# -
# Database
# -
async def ensure_migration_table(config: Dict) -> bool:
    """Create the ``migration`` bookkeeping table if it does not exist yet.

    Args:
        config: keyword arguments passed verbatim to ``asyncpg.connect``.

    Returns:
        Always ``True``; database errors propagate to the caller.
    """
    conn: asyncpg.Connection = await asyncpg.connect(**config)
    try:
        await conn.execute('''CREATE TABLE IF NOT EXISTS migration (
            id SERIAL PRIMARY KEY,
            file_name VARCHAR(200) NOT NULL,
            date TIMESTAMP NOT NULL DEFAULT NOW()
        )''')
    finally:
        # Each helper opens its own connection, so it must also release it.
        await conn.close()
    return True
async def select_migrations(config: Dict) -> List[asyncpg.Record]:
    """Return every row of the ``migration`` table.

    Args:
        config: keyword arguments passed verbatim to ``asyncpg.connect``.

    Returns:
        The records produced by ``SELECT * FROM migration``.
    """
    conn: asyncpg.Connection = await asyncpg.connect(**config)
    try:
        return await conn.fetch('''SELECT * FROM migration''')
    finally:
        # Release the connection even when the query fails.
        await conn.close()
async def select_migration_file_name(config: Dict, file_name) -> asyncpg.Record:
    """Look up the ``migration`` row recorded for ``file_name``.

    Args:
        config: keyword arguments passed verbatim to ``asyncpg.connect``.
        file_name: migration file name to search for (bound as ``$1``).

    Returns:
        The first matching record; ``fetchrow`` yields ``None`` when no row
        matches, which callers use as "not applied".
    """
    conn: asyncpg.Connection = await asyncpg.connect(**config)
    try:
        return await conn.fetchrow(
            '''SELECT * FROM migration where file_name = $1''', file_name)
    finally:
        # Release the connection even when the query fails.
        await conn.close()
async def insert_migration(config: Dict, file_name):
    """Record ``file_name`` as applied by inserting it into ``migration``.

    Args:
        config: keyword arguments passed verbatim to ``asyncpg.connect``.
        file_name: migration file name to store (bound as ``$1``).
    """
    conn: asyncpg.Connection = await asyncpg.connect(**config)
    try:
        await conn.execute(
            '''INSERT INTO migration (file_name) VALUES ($1)''', file_name)
    finally:
        # Release the connection even when the statement fails.
        await conn.close()
async def delete_migration(config: Dict, file_name):
    """Remove the ``migration`` row(s) recorded for ``file_name``.

    Args:
        config: keyword arguments passed verbatim to ``asyncpg.connect``.
        file_name: migration file name to delete (bound as ``$1``).
    """
    conn: asyncpg.Connection = await asyncpg.connect(**config)
    try:
        await conn.execute(
            '''DELETE FROM migration where file_name = $1''', file_name)
    finally:
        # Release the connection even when the statement fails.
        await conn.close()
# -
# Utils
# -
def import_module(directory: str, fname: str) -> Any:
    """Import the migration module stored in ``directory``/``fname``.

    The directory is converted into a dotted package path, so nested
    directories (``db/migrations``), a leading ``./`` and an empty directory
    (file located directly on the import path) are all supported.

    Args:
        directory: path of the migration directory, relative to an entry of
            ``sys.path`` (``run`` puts the current working directory there).
        fname: migration file name including its ``.py`` extension.

    Returns:
        The imported module object.

    Raises:
        ImportError: if the module cannot be found or imported.
    """
    import importlib

    module_name = os.path.splitext(fname)[0]
    packages = []
    if directory:
        # normpath unifies separators and drops './' and duplicate slashes.
        packages = [
            part
            for part in os.path.normpath(directory).split(os.sep)
            if part not in ('', '.')
        ]
    # importlib returns the leaf module itself, unlike __import__, which
    # returns the top-level package and needs an attribute walk afterwards.
    return importlib.import_module('.'.join(packages + [module_name]))
def print_error(msg: str):
    """Print ``msg`` to standard output, tagged with the ``--ERROR--`` marker."""
    line = '--ERROR-- {}'.format(msg)
    print(line)
def print_info(msg: str):
    """Print ``msg`` to standard output without any decoration."""
    line = '{}'.format(msg)
    print(line)
# -
# Argument call
# -
def add_migration(name: str, directory: str):
    """Create a new migration file from ``migration_template``.

    The file is named ``{unix_timestamp}_{name}.py``; whitespace in ``name``
    is replaced by underscores.

    Args:
        name: human readable migration name, e.g. ``"add users"``.
        directory: existing directory in which the file is created.

    Raises:
        RuntimeError: if the target path already exists.
    """
    dt = int(time.time())
    # split() without arguments trims the ends and collapses runs of any
    # whitespace (spaces, tabs, newlines) into single separators.
    name = '_'.join(name.split())
    fpath = os.path.join(directory, f'{dt}_{name}.py')
    try:
        # 'x' creates the file exclusively: the existence check and the
        # creation are one atomic step and an existing file is never truncated.
        with open(fpath, 'x', encoding='utf-8') as f:
            f.write(migration_template)
    except FileExistsError:
        raise RuntimeError(f'Path {fpath} already exists') from None
async def migrate_database(config: Dict, fname: str, directory: str):
    """Run the migrate command.

    Makes sure the migration table exists, then applies either the single
    file given by ``fname`` or, when ``fname`` is the literal ``'all'``,
    every pending migration found in ``directory``.
    """
    await ensure_migration_table(config)
    # A concrete file path was requested: apply just that one and stop.
    if fname != 'all':
        await migrate_one(config, fname)
        return
    await migrate_all(config, directory)
async def migrate_all(config: Dict, directory: str):
    """Apply every migration in ``directory`` that is not recorded yet.

    Candidate files end with ``.py`` and do not start with ``__``. A file is
    considered applied when its name appears in the ``file_name`` column of
    the migration table. For each pending file the module's ``up(config)``
    coroutine is awaited and the file name is then recorded.

    Args:
        config: keyword arguments for ``asyncpg.connect``.
        directory: directory containing the migration files.
    """
    rows = await select_migrations(config)
    # One set lookup per file instead of rescanning all rows for every file.
    applied = {row['file_name'] for row in rows}
    # os.listdir() returns entries in arbitrary order; sorting applies the
    # timestamp-prefixed files chronologically and matches list_migrations().
    for fname in sorted(os.listdir(directory)):
        if not fname.endswith('.py') or fname.startswith('__'):
            continue
        if fname in applied:
            print_info(f'Skipping {fname}')
            continue
        module = import_module(directory=directory, fname=fname)
        await module.up(config)
        await insert_migration(config, file_name=fname)
        print_info(f'Applying {fname}')
async def migrate_one(config: Dict, fpath: str):
    """Apply the single migration file at ``fpath`` and record it.

    Nothing is applied (an error is printed instead) when the file does not
    exist or when its name is already present in the migration table. The
    migration table itself is expected to exist already; ``migrate_database``
    ensures it before calling this function.

    Args:
        config: keyword arguments for ``asyncpg.connect``.
        fpath: path of the migration file, e.g. ``migrations/1602549074_foo.py``.
    """
    if not os.path.exists(fpath):
        print_error(f'Migrate file `{fpath}` not exists')
        return
    # os.path.split honours the platform's path separators, unlike a manual
    # split on '/'.
    directory, fname = os.path.split(fpath)
    # check if migration exists
    result = await select_migration_file_name(config, file_name=fname)
    if result:
        print_error(f'Migration `{fpath}` found in migration table')
        return
    # migrate
    module = import_module(directory=directory, fname=fname)
    await module.up(config)
    await insert_migration(config, file_name=fname)
    print_info(f'Migrated file `{fpath}`')
async def rollback_migration(config: Dict, fpath: str):
    """Roll back the migration file at ``fpath`` and forget it.

    Nothing is rolled back (an error is printed instead) when the file does
    not exist or when its name is not present in the migration table.
    Otherwise the module's ``down(config)`` coroutine is awaited and the
    file's row is deleted from the migration table.

    Args:
        config: keyword arguments for ``asyncpg.connect``.
        fpath: path of the migration file, e.g. ``migrations/1602549074_foo.py``.
    """
    if not os.path.exists(fpath):
        print_error(f'Rollback file `{fpath}` not exists')
        return
    # os.path.split honours the platform's path separators, unlike a manual
    # split on '/'.
    directory, fname = os.path.split(fpath)
    # check if migration exists
    await ensure_migration_table(config)
    result = await select_migration_file_name(config, file_name=fname)
    if not result:
        print_error(f'Migration `{fpath}` not found in migration table')
        return
    module = import_module(directory=directory, fname=fname)
    await module.down(config)
    await delete_migration(config, file_name=fname)
    print_info(f'Rollback migration `{fname}`')
async def list_migrations(config: Dict, directory: str):
    """Print the applied (``DONE``) and pending (``TODO``) migration files.

    Candidate files in ``directory`` end with ``.py`` and do not start with
    ``__``; they are processed in sorted order. A file is ``DONE`` when its
    name appears in the ``file_name`` column of the migration table. All
    ``DONE`` lines are printed before the ``TODO`` lines.

    Args:
        config: keyword arguments for ``asyncpg.connect``.
        directory: directory containing the migration files.
    """
    await ensure_migration_table(config)
    rows = await select_migrations(config)
    # One set lookup per file instead of rescanning all rows for every file.
    applied = {row['file_name'] for row in rows}
    done = []
    todo = []
    for fname in sorted(os.listdir(directory)):
        if not fname.endswith('.py') or fname.startswith('__'):
            continue
        if fname in applied:
            done.append(fname)
        else:
            todo.append(fname)
    # Plain loops: the printing is a side effect, not list construction.
    for name in done:
        print_info(f'DONE {name}')
    for name in todo:
        print_info(f'TODO {name}')
def _run_db_command(config_path: str, command, *command_args) -> None:
    """Initialise ``MigrateConfig`` from ``config_path`` and run ``command``.

    ``command`` is called as ``command(db_config, *command_args)`` and must
    return a coroutine, which is driven to completion on a private event loop.
    """
    config = MigrateConfig.instance()
    config.init(config_path)
    # Use an explicit loop: relying on asyncio.get_event_loop() to create one
    # implicitly is deprecated in recent Python versions, and an explicit loop
    # can be closed once the command has finished.
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(command(config.db_config(), *command_args))
    finally:
        loop.close()


def main(args: argparse.Namespace):
    """Dispatch the parsed command line to the matching migration command.

    Precedence of the commands is: add migration, rollback, migrate, list.
    When none of them is selected a short usage text is printed.

    Args:
        args: namespace with the attributes ``config``, ``directory``,
            ``add_migration``, ``rollback_migration``, ``list`` and
            ``migrate`` (as produced by the parser built in ``run``).
    """
    # Creating a migration file never touches the database, so a missing
    # configuration file must not block the add-migration command.
    needs_config = args.add_migration is None
    # check if files exists
    if needs_config and args.config and not os.path.exists(args.config):
        path_err = f'''Path --config {args.config} not exists
for example:
db:
host: 127.0.0.1
port: 5432
user: postgres
password: postgres
database: asyncpg-migate-test
'''
        print_error(path_err)
        return
    if args.directory and not os.path.exists(args.directory):
        print_error(f'Directory --directory {args.directory} not exists')
        return
    # run migration commands
    has_config = args.config is not None
    if args.add_migration is not None:
        # add file
        add_migration(args.add_migration, args.directory)
    elif has_config and args.rollback_migration is not None:
        # rollback
        _run_db_command(args.config, rollback_migration, args.rollback_migration)
    elif has_config and args.migrate is not None:
        # migrate file or all not migrated
        _run_db_command(args.config, migrate_database, args.migrate, args.directory)
    elif has_config and args.list is not None:
        # list applied and pending migrations
        _run_db_command(args.config, list_migrations, args.directory)
    else:
        msg = """
-h - help
-m -c conf/local.yaml - migrate all
-m migrations/1602549074_foo.py -c conf/local.yaml - migrate one
-a foo add file migrations/{timestamp}_foo.py
-r migrations/1602549074_foo.py -c conf/local.yaml - remove one
-l -c conf/local.yaml - list migrations
"""
        print_info(msg)
def run():
    """Command line entry point: parse ``sys.argv`` and hand over to ``main``.

    The current working directory is put first on ``sys.path`` before parsing,
    presumably so that migration modules below it can be imported.
    """
    sys.path.insert(0, os.getcwd())
    # (flags, keyword options) for every supported command line switch, in the
    # order they are registered with the parser.
    option_table = (
        (('-c', '--config'),
         {'help': 'Configuration yaml file ex. `conf/db.yaml`',
          'default': 'conf/db.yaml'}),
        (('-d', '--directory'),
         {'help': 'Migration directory default to `migrations`',
          'default': 'migrations'}),
        (('-a', '--add-migration'),
         {'help': 'Add new migration file with `name of file`'}),
        (('-r', '--rollback-migration'),
         {'help': 'Rollback migration file based on provided file path'}),
        (('-l', '--list'),
         {'help': 'List migrations', 'nargs': '?', 'const': 'all'}),
        (('-m', '--migrate'),
         {'help': 'Migrate database optional `file_path`',
          'nargs': '?', 'const': 'all'}),
    )
    parser = argparse.ArgumentParser()
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    main(parser.parse_args())


if __name__ == '__main__':
    run()