#!/usr/bin/python3
# Simple HTTP web server that forwards prometheus alerts over XMPP.
#
# To use, configure a web hook in alertmanager. E.g.:
#
# receivers:
# - name: 'jelmer-pager'
#   webhook_configs:
#   - url: 'http://192.168.2.1:9199/alert'
#
# Edit xmpp-alerts.yml.example, then run:
# $ python3 prometheus-xmpp-alerts --config=xmpp-alerts.yml.example
import argparse
import json
import logging
import shlex
import socket
import sys

import slixmpp
import yaml
from aiohttp import web
from prometheus_client import (
    Counter,
    generate_latest,
    CONTENT_TYPE_LATEST,
)

from prometheus_xmpp import (
    create_message_short,
    create_message_full,
    run_amtool,
)


# Configuration file path used when none is supplied explicitly.
DEFAULT_CONF_PATH = '/etc/prometheus/xmpp-alerts.yml'

# Prometheus counters, one per kind of event this service tracks.
alert_counter = Counter(
    name='alert_count',
    documentation='Total number of alerts delivered')
test_counter = Counter(
    name='test_count',
    documentation='Total number of test alerts delivered')
xmpp_message_counter = Counter(
    name='xmpp_message_count',
    documentation='Total number of XMPP messages received.')


class XmppApp(slixmpp.ClientXMPP):
    """XMPP client that answers a small set of chat commands.

    ``alert`` and ``silence`` commands are handed to ``run_amtool`` when the
    sender's bare JID is listed in ``amtool_allowed``; ``help`` lists the
    supported commands.
    """

    def __init__(self, jid, password, amtool_allowed=None):
        """Create the client and register its event handlers and plugins.

        Args:
          jid: JID to log in as.
          password: Password for ``jid``.
          amtool_allowed: Bare JIDs that may run amtool commands.
            ``None`` (or an empty list) means nobody may.
        """
        super().__init__(jid, password)
        self._amtool_allowed = amtool_allowed or []
        self.auto_authorize = True
        self.add_event_handler("session_start", self.start)
        self.add_event_handler("message", self.message)
        self.add_event_handler("disconnected", self.lost)
        self.register_plugin('xep_0030')  # Service Discovery
        self.register_plugin('xep_0004')  # Data Forms
        self.register_plugin('xep_0060')  # PubSub
        self.register_plugin('xep_0199')  # XMPP Ping

    def start(self, event):
        """Process the session_start event.

        Announces presence as available and requests the roster.

        Args:
          event: Event data (empty)
        """
        self.send_presence(ptype='available', pstatus='Active')
        self.get_roster()

    def lost(self, event):
        """Process the disconnected event by exiting with status 1.

        Args:
          event: Event data (unused).
        """
        logging.info("Connection lost, exiting.")
        sys.exit(1)

    def message(self, msg):
        """Handle an incoming message.

        Only messages of type ``chat`` or ``normal`` are processed; every
        processed message is counted and receives exactly one reply.

        Args:
            msg: The received message stanza.
        """
        if msg['type'] not in ('chat', 'normal'):
            return
        # Keep the exported xmpp_message_count metric in step with the
        # messages this handler actually processes.
        xmpp_message_counter.inc()
        try:
            args = shlex.split(msg['body'])
        except ValueError as e:
            # shlex rejects bodies with unbalanced quotes (e.g. a stray
            # apostrophe); tell the sender instead of failing silently.
            response = "Unable to parse command: %s" % e
        else:
            if not args:
                response = "No command specified"
            elif args[0] in ('alert', 'silence'):
                if msg['from'].bare in self._amtool_allowed:
                    response = run_amtool(args)
                else:
                    response = "Unauthorized JID."
            elif args[0] == 'help':
                response = "Supported commands: help, alert, silence."
            else:
                response = "Unknown command: %s" % args[0]
        msg.reply(response).send()


# Command-line interface: config file location and log verbosity.
parser = argparse.ArgumentParser()
parser.add_argument('--config', dest='config_path',
                    type=str, default=DEFAULT_CONF_PATH,
                    help='Path to configuration file.')
parser.add_argument("-q", "--quiet", help="set logging to ERROR",
                    action="store_const", dest="loglevel",
                    const=logging.ERROR, default=logging.INFO)
parser.add_argument("-d", "--debug", help="set logging to DEBUG",
                    action="store_const", dest="loglevel",
                    const=logging.DEBUG, default=logging.INFO)

args = parser.parse_args()

# Setup logging.
logging.basicConfig(level=args.loglevel, format='%(levelname)-8s %(message)s')

# NOTE(review): the file is parsed with FullLoader when available; if the
# configuration only ever holds plain scalars and lists, yaml.safe_load
# would be the stricter choice.
with open(args.config_path) as f:
    if getattr(yaml, 'FullLoader', None):
        config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        # Backwards compatibility with older versions of PyYAML, which do
        # not provide FullLoader.
        config = yaml.load(f)

# An empty file parses to None and a bare scalar/list to a non-mapping;
# fail with a clear message instead of an opaque TypeError further down.
if not isinstance(config, dict):
    parser.error(
        "configuration file %s must contain a YAML mapping"
        % args.config_path)

# Both keys are read unconditionally below, so report them up front.
missing_keys = [key for key in ('jid', 'to_jid') if key not in config]
if missing_keys:
    parser.error(
        "configuration file %s is missing required setting(s): %s"
        % (args.config_path, ', '.join(missing_keys)))

# Append the local hostname to the configured JID as its resource part.
hostname = socket.gethostname()
jid = "{}/{}".format(config['jid'], hostname)

# Unless 'amtool_allowed' is configured, only the alert recipient may issue
# amtool commands.
app = XmppApp(jid, config.get('password'),
              config.get('amtool_allowed', [config['to_jid']]))
app.connect()


async def serve_test(request):
    """Send a fixed test message to the configured recipient.

    Args:
      request: The incoming HTTP request (unused).

    Returns:
      A response confirming the message was passed to the XMPP client.
    """
    test_counter.inc()
    # The value returned by send_message is not needed here.
    app.send_message(
        mto=config['to_jid'],
        mbody='Test message',
        mtype='chat')
    return web.Response(body='Sent message.')


async def serve_alert(request):
    """Forward an alert posted as JSON to the configured recipient.

    The request body is decoded as JSON and rendered with
    ``create_message_full`` when the ``format`` setting is ``'full'``,
    otherwise with ``create_message_short``.

    Args:
      request: The incoming HTTP request carrying the alert as JSON.

    Returns:
      A response confirming the message was passed to the XMPP client.

    Raises:
      web.HTTPUnprocessableEntity: If the request body is not valid JSON.
    """
    alert_counter.inc()
    try:
        alert = await request.json()
    except json.JSONDecodeError as e:
        # aiohttp HTTP exceptions accept keyword arguments only; passing the
        # message positionally raised TypeError and produced a 500 instead
        # of the intended 422.
        raise web.HTTPUnprocessableEntity(text=str(e)) from e
    if config.get('format') == 'full':
        text = '\n--\n'.join(create_message_full(alert))
    else:
        text = '\n'.join(create_message_short(alert))
    app.send_message(
        mto=config['to_jid'],
        mbody=text,
        mtype='chat')
    return web.Response(body='Sent message')


async def serve_metrics(request):
    """Return the output of ``generate_latest()`` as the response body.

    Args:
      request: The incoming HTTP request (unused).
    """
    payload = generate_latest()
    response = web.Response(body=payload)
    # Label the payload with the content type exported by prometheus_client.
    response.content_type = CONTENT_TYPE_LATEST
    return response


async def serve_root(request):
    """Answer requests for ``/`` with a pointer to the other endpoints.

    Args:
      request: The incoming HTTP request (unused).
    """
    hint = 'See /test, /alert or /metrics'
    return web.Response(body=hint)


# Route table for the HTTP front end; /test and /alert accept GET and POST.
routes = [
    web.get('/', serve_root),
    web.get('/test', serve_test),
    web.post('/test', serve_test),
    web.get('/alert', serve_alert),
    web.post('/alert', serve_alert),
    web.get('/metrics', serve_metrics),
]

web_app = web.Application()
web_app.add_routes(routes)

# Serve on the address and port given in the configuration file.
listen_host = config['listen_address']
listen_port = config['listen_port']
web.run_app(web_app, host=listen_host, port=listen_port)
