eoptasks/eoptasks.py

305 lines
11 KiB
Python
Executable File

#! /usr/bin/python3
#
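# This script provides parallel remote execution of commands, while having
# some special knowledge of servers that should *not* be handled in parallel
# (servers of a same cluster, i.e. whose names only differ by trailing digits
# or by their .rbx./.gra./.sbg. location part, are run one at a time).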
#
# It defers terminal-handling to tmux(1).
#
# It has some targeting capabilities using keywords: commas mean 'OR' and
# slashes mean 'AND', e.g. ext/test,saas/test/passerelle will select all
# external test servers + all test passerelle servers on the SaaS. A keyword
# prefixed with '-' or '!' excludes the servers having it.
#
# It takes any shell command and has some builtin shortcuts such as apt.update
# and apt.upgrade.
#
# Requirements: libtmux and pyyaml.
#
# Configuration: ~/.config/eoptasks.ini
# [config]
# servergroups = /path/to/servergroups.yaml
# ignore = server1, server2
#
# Examples:
#
# eoptasks -k test apt.upgrade
# Run (sudo) apt upgrade on all test servers.
#
# eoptasks -k test,-database sudo apt install python-gadjo
# Run sudo apt install python-gadjo on all test servers except database servers.
#
# eoptasks -k saas/test/passerelle,ext/test --list-servers
# List servers that have saas AND test AND passerelle keywords, OR the
# ext AND test keywords.
import argparse
import configparser
import curses
import json
import os
import random
import re
import shlex
import socket
import subprocess
import sys
import time

import libtmux
import yaml
class Server:
    """A remote host on which commands are run over ssh.

    ``keywords`` is the set of words the server can be selected with: the
    pieces of its name and of its group split on '-', '_', ' ' and '.',
    the group name itself, and every run of two or more consecutive
    dot-separated parts of the hostname.
    """

    def __init__(self, servername, group=''):
        self.name = servername
        self.keywords = set(re.split(r'[-_ \.]', servername + ' ' + group))
        self.keywords.add(group)
        # add all possible hostname parts as keywords,
        # e.g. node1.dev.entrouvert.org will add:
        #   node1.dev, node1.dev.entrouvert, node1.dev.entrouvert.org,
        #   dev.entrouvert, dev.entrouvert.org, entrouvert.org
        parts = servername.split('.')
        for start in range(len(parts) - 1):
            for end in range(start + 2, len(parts) + 1):
                self.keywords.add('.'.join(parts[start:end]))
        # an empty group (or the split on the separating space) yields '',
        # which is not a keyword anyone can meaningfully select
        self.keywords.discard('')

    def __repr__(self):
        return '<Server %s %r>' % (self.name, self.keywords)

    def shell(self):
        """Return the command line opening an interactive ssh session."""
        return 'ssh %s' % shlex.quote(self.name)

    def cmd(self, cmd, *args):
        """Return the command line running *cmd* followed by *args* remotely.

        The line is meant to be run by tmux through a shell, so the remote
        command is passed to ssh as a single shell-quoted word: quotes,
        '$', backquotes or '&&' inside it reach the remote shell unchanged
        instead of being interpreted (or breaking the quoting) locally.
        """
        remote_command = ' '.join((cmd,) + args)
        return 'ssh -t %s %s' % (shlex.quote(self.name), shlex.quote(remote_command))
def get_servers():
    """Return a Server for every host of the configured servergroups file.

    Reads ~/.config/eoptasks.ini, section [config]:
      servergroups -- path of a YAML file whose top-level ``servergroups``
                      mapping associates each group name with a list of
                      server names (``~`` is expanded);
      ignore       -- optional comma-separated server names to skip.
    Prints a usage hint and exits with status 1 when ``servergroups`` is
    not configured.
    """
    config = configparser.ConfigParser()
    config.read(os.path.expanduser('~/.config/eoptasks.ini'))
    servergroup = config.get('config', 'servergroups', fallback=None)
    if servergroup is None:
        print("You need to create ~/.config/eoptasks.ini with such a content:\n"
              "\n"
              "  [config]\n"
              "  servergroups = /home/user/src/puppet/data/servergroups.yaml\n")
        sys.exit(1)
    ignorelist = {x.strip() for x in config.get('config', 'ignore', fallback='').split(',')}
    ignorelist.discard('')
    # safe_load: the file is plain data, and yaml.load() without an explicit
    # Loader is unsafe and rejected outright by PyYAML >= 6.
    with open(os.path.expanduser(servergroup)) as fd:
        servergroups = yaml.safe_load(fd)['servergroups']
    servers = []
    for group, servernames in servergroups.items():
        # a group without any server is loaded as None
        for servername in servernames or []:
            if servername not in ignorelist:
                servers.append(Server(servername, group))
    return servers
def parse_args():
    """Parse sys.argv and return the resulting argparse.Namespace.

    Attributes:
      list_servers -- True when -l/--list-servers is given;
      session_name -- value of --status-window (used when the script runs
                      itself again inside tmux to draw the status window);
      keywords     -- the -k selection expression, or None;
      cmd          -- the command or builtin shortcut, or None;
      args         -- every remaining argument, passed through untouched.
    """
    option_table = (
        (('-l', '--list-servers'), {'action': 'store_true'}),
        (('--status-window',), {'dest': 'session_name', 'type': str}),
        (('-k',), {'dest': 'keywords', 'type': str}),
        (('cmd',), {'type': str, 'nargs': '?', 'default': None}),
        (('args',), {'nargs': argparse.REMAINDER}),
    )
    arg_parser = argparse.ArgumentParser()
    for flags, settings in option_table:
        arg_parser.add_argument(*flags, **settings)
    return arg_parser.parse_args()
def filter_servers(servers, args):
    """Return the servers selected by the -k expression in ``args.keywords``.

    The expression is a comma-separated list of terms ('OR'); each term is
    a slash-separated list of keywords that must all be present ('AND').
    A term prefixed with '-' or '!' excludes the servers it matches
    instead. Surrounding whitespace and empty terms are ignored.

    Without any expression every server is selected; with only exclusion
    terms, every server that is not excluded is selected. The result is a
    new list, in the order terms matched, without duplicates.
    """
    if not args.keywords:
        return list(servers)

    includes, excludes = [], []
    for term in args.keywords.split(','):
        term = term.strip()
        negated = term[:1] in ('-', '!')
        if negated:
            term = term[1:]
        words = {word for word in term.split('/') if word}
        if not words:
            continue  # empty term, e.g. a trailing comma
        (excludes if negated else includes).append(words)

    if includes:
        selected = []
        seen = set()  # ids, to keep identity semantics and O(1) lookups
        for words in includes:
            for server in servers:
                if words <= server.keywords and id(server) not in seen:
                    seen.add(id(server))
                    selected.append(server)
    else:
        # only exclusions were given: start from every server
        selected = list(servers)

    return [server for server in selected
            if not any(words <= server.keywords for words in excludes)]
def status_window(session_name):
    """Draw the progress of every server in the first tmux window.

    The main process runs this script again with ``--status-window
    <session_name>``. This function listens on the UNIX socket
    /tmp/.eoptasks.<session_name>. Each connection delivers one JSON
    object mapping server names to ``{'status': ...}`` dicts, which is
    drawn as a grid of "icon name" cells. Once every server is reported
    'done', it shows 😎 and returns after five seconds.
    """
    curses.setupterm()
    window = curses.initscr()
    window.addstr(0, 0, 'eoptasks', curses.A_STANDOUT)
    window.addstr(0, 10, '🙂')
    curses.curs_set(0)
    window.refresh()

    server_address = '/tmp/.eoptasks.%s' % session_name
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    sock.bind(server_address)
    sock.listen(1)
    try:
        while True:
            try:
                servers_info = _receive_status(sock)
            except ValueError as e:
                # malformed message: report it and wait for the next one
                _show_status_error(window, e)
                continue
            try:
                _draw_status(window, servers_info)
            except Exception as e:
                # drawing is best effort (e.g. a terminal too small for the
                # grid); it must never prevent the completion check below
                _show_status_error(window, e)
            if all(isinstance(info, dict) and info.get('status') == 'done'
                   for info in servers_info.values()):
                break
    finally:
        # remove the socket file even when leaving because of an error
        os.unlink(server_address)
    window.addstr(0, 10, '😎')
    window.refresh()
    time.sleep(5)
    sock.close()


def _receive_status(sock):
    """Accept one connection on *sock* and return the JSON dict it carries.

    Raises ValueError when the payload is not UTF-8 encoded JSON describing
    a dict.
    """
    connection, _ = sock.accept()
    with connection:
        chunks = []
        while True:
            data = connection.recv(5000)
            if not data:
                break
            chunks.append(data)
    servers_info = json.loads(b''.join(chunks).decode('utf-8'))
    if not isinstance(servers_info, dict):
        raise ValueError('unexpected status message: %r' % (servers_info,))
    return servers_info


def _draw_status(window, servers_info):
    """Draw one status icon and name per server, in as many columns as fit."""
    height, width = window.getmaxyx()
    max_length = max((len(name) for name in servers_info), default=0) + 4
    # at least one column: a terminal narrower than one cell would otherwise
    # make every redraw fail on a division by zero
    nb_columns = max(1, (width - 4) // max_length)
    column_width = width // nb_columns
    for i, server_name in enumerate(servers_info):
        y = 2 + i // nb_columns
        x = 1 + column_width * (i % nb_columns)
        window.addstr(y, x + 3, server_name)
        status_icon = {
            'running': '',
            'done': '🆗',
        }.get(servers_info[server_name]['status'], '💤')
        window.addstr(y, x, status_icon)
        if y > height - 4:
            break  # no room left for further rows
    window.refresh()


def _show_status_error(window, error):
    """Replace the header icon with an angry face and the error's repr."""
    try:
        window.addstr(0, 10, '😡 %r' % error)
        window.refresh()
    except curses.error:
        pass  # the message does not fit on screen; keep running regardless
args = parse_args()

if args.session_name:
    # This process is the status display started inside tmux.
    status_window(args.session_name)
    sys.exit(0)

servers = get_servers()
selected_servers = filter_servers(servers, args)

if args.list_servers:
    for listed_name in sorted(server.name for server in selected_servers):
        print(listed_name)
    sys.exit(0)

# Refuse to open a tmux session when there is nothing to do.
for failed, message in ((not selected_servers, 'No matching servers\n'),
                        (not args.cmd, 'Missing command\n')):
    if failed:
        sys.stderr.write(message)
        sys.exit(1)
def init_tmux_session():
    """Create a detached tmux session running the status window; return its name.

    The session's single window (named 🌑) runs this script again with
    ``--status-window <session name>``. Names already used by an existing
    session are skipped. Reusing one would delete that session's status
    socket, and this run would then drive somebody else's windows.
    Exits with status 1 when no session could be created.
    """
    for _ in range(100):
        tmux_session_name = 's%s' % random.randrange(1000)
        # has-session may also succeed when the name only prefixes an
        # existing session; that merely costs one more try
        taken = subprocess.run(
            ['tmux', 'has-session', '-t', tmux_session_name],
            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
        if taken:
            continue
        server_address = '/tmp/.eoptasks.%s' % tmux_session_name
        try:
            os.unlink(server_address)
        except OSError:
            pass  # no stale socket left over by a previous run
        # a single shell-command argument, with the script path quoted
        status_command = '%s --status-window %s' % (
            shlex.quote(sys.argv[0]), tmux_session_name)
        result = subprocess.run(
            ['tmux', 'new-session', '-s', tmux_session_name, '-n', '🌑', '-d',
             status_command])
        if result.returncode != 0:
            sys.stderr.write('Failed to create tmux session %s\n' % tmux_session_name)
            sys.exit(1)
        return tmux_session_name
    sys.stderr.write('No free tmux session name found\n')
    sys.exit(1)
tmux_session_name = init_tmux_session()
pid = os.fork()
if pid:
    # Parent: hand the terminal over to tmux until the user detaches or the
    # session ends.
    os.system('tmux attach-session -t %s' % tmux_session_name)
else:
    # Child: open one tmux window per selected server and keep the status
    # window informed of the progress.

    def cluster_name(server_name):
        """Return the cluster key of *server_name*.

        Digits at the very end of the name are dropped and the .rbx., .gra.
        and .sbg. parts are replaced by .loc. So 'db1' and 'db2' share a
        cluster, and so do 'web.rbx.x' and 'web.gra.x'. Servers of the same
        cluster are never run at the same time.
        """
        return re.match(r'(.*?)(\d*)$', server_name).group(1).replace(
            '.rbx.', '.loc.').replace('.gra.', '.loc.').replace('.sbg.', '.loc.')

    # maximum number of servers handled at the same time
    max_running = 10

    tmux = libtmux.Server()
    session = tmux.find_where({'session_name': tmux_session_name})
    # builtin shortcuts; any other command is run as given
    cmd = {
        'apt.update': 'sudo apt update',
        'apt.upgrade': 'sudo apt update && sudo apt full-upgrade -y',
        # collectstatic is useful after an upgrade of gadjo.
        'collectstatic': (
            'sudo -u authentic-multitenant authentic2-multitenant-manage collectstatic --noinput;'
            'sudo -u bijoe bijoe-manage collectstatic --noinput;'
            'sudo -u chrono chrono-manage collectstatic --noinput;'
            'sudo -u combo combo-manage collectstatic --noinput;'
            'sudo -u corbo corbo-manage collectstatic --noinput;'
            'sudo -u fargo fargo-manage collectstatic --noinput;'
            'sudo -u hobo hobo-manage collectstatic --noinput;'
            'sudo -u passerelle passerelle-manage collectstatic --noinput;'
            'sudo -u wcs wcs-manage collectstatic;'
            '/bin/true'),
        # combo.reload is useful to get a new {% start_timestamp %} after an
        # upgrade of publik-base-theme.
        'combo.reload': '''sudo service combo reload; /bin/true''',
        # hobo-agent.restart is the fastest way to get the number of threads
        # used by celery under control :/
        'hobo-agent.restart': '''test -e /etc/hobo-agent/settings.py && sudo supervisorctl restart hobo-agent''',
    }.get(args.cmd, args.cmd)
    # the session's only window so far: the one running the status display
    status_tmux_window = session.attached_window
    all_servers = selected_servers[:]
    total_number = len(selected_servers)
    random.shuffle(selected_servers)
    servers_info = {server.name: {'status': ''} for server in selected_servers}

    def send_status():
        """Refresh servers_info from the tmux windows and send it to the status window.

        A server is 'running' while a window named after it exists, and
        'done' once it has been launched (removed from selected_servers)
        and has no window anymore. Servers not launched yet keep an empty
        status. The update is dropped when the status window is not
        listening.
        """
        current_windows = {x.name for x in session.list_windows()}
        for server in all_servers:
            server_info = servers_info[server.name]
            if server.name in current_windows:
                server_info['status'] = 'running'
            elif server not in selected_servers:
                # launched and gone: this also covers a window that closed
                # before it was ever seen running, which would otherwise
                # never reach 'done' and keep the status window waiting
                server_info['status'] = 'done'
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            try:
                sock.connect('/tmp/.eoptasks.%s' % tmux_session_name)
                sock.sendall(json.dumps(servers_info).encode('utf-8'))
            except OSError:
                return

    while selected_servers:
        busy_clusters = {cluster_name(x.name) for x in session.list_windows()}
        server = next(
            (s for s in selected_servers if cluster_name(s.name) not in busy_clusters),
            None)
        if server is None:
            # every pending server has a cluster sibling running: wait
            time.sleep(0.1)
        else:
            selected_servers.remove(server)
            session.new_window(
                attach=False,
                window_name=server.name,
                window_shell=server.cmd(cmd, *args.args))
        # list_windows() includes the status window. Waiting while it exceeds
        # max_running leaves room for exactly one more server, so at most
        # max_running servers run at once.
        while len(session.list_windows()) > max_running:
            send_status()
            time.sleep(0.1)
        send_status()
        launched = (total_number - len(selected_servers)) / total_number
        for threshold, moon in ((1, '🌕'), (0.75, '🌔'), (0.5, '🌓'), (0.25, '🌒')):
            if launched >= threshold:
                status_tmux_window.rename_window(moon)
                break

    # wait until only the status window remains
    while len(session.list_windows()) > 1:
        send_status()
        time.sleep(0.1)
    status_tmux_window.rename_window('🌕')
    send_status()