#!/usr/bin/env python
#
# Assert users that came from metadata config
#
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#

from __future__ import print_function

import argparse
import base64
import json
import logging
import os
import subprocess

import MySQLdb

logging.basicConfig()
logger = logging.getLogger('mysql-users')

# Credentials are tried in three steps: a connection with no password, then
# ~/.my.cnf, and finally /mnt/state/root/metadata.my.cnf. A step is attempted
# only when the previous one raised; an error from the last step propagates.
# This should cover os-refresh-config post install and re-image.
try:
    conn = MySQLdb.Connect()
except Exception:
    try:
        home_cnf = os.path.expanduser('~/.my.cnf')
        conn = MySQLdb.Connect(read_default_file=home_cnf)
    except Exception:
        conn = MySQLdb.Connect(
            read_default_file='/mnt/state/root/metadata.my.cnf')

# Snapshot the distinct, non-empty user names currently stored in mysql.user.
cursor = conn.cursor()
row_count = cursor.execute(
    "SELECT DISTINCT User FROM mysql.user WHERE user != ''")
existing = {record[0] for record in cursor.fetchmany(size=row_count)}
cursor.close()

# Populated by load_userfile(): the user names that are expected to exist, and
# the full JSON entry for each of them keyed by user name.
should_exist = set()
by_user = {}


def load_userfile(path):
    """Record the user definitions found in the JSON file at *path*.

    A path that does not exist is skipped without error. Otherwise the file
    must contain a JSON list; for every element the value under 'username'
    is added to the module-level ``should_exist`` set and the element itself
    is stored in the module-level ``by_user`` dict under that name,
    replacing any entry already stored for the same name.

    Raises ValueError if the top-level JSON value is not a list, and
    KeyError if an element has no 'username' key.
    """
    if not os.path.exists(path):
        return

    with open(path) as handle:
        entries = json.load(handle)

    if not isinstance(entries, list):
        raise ValueError('%s must be a list' % (path))

    for entry in entries:
        name = entry['username']
        should_exist.add(name)
        by_user[name] = entry

# Command line: with --noop/-n the SQL statements and the cfn-signal command
# are printed instead of being executed.
parser = argparse.ArgumentParser()
parser.add_argument('--noop', '-n', default=False, action='store_true')

opts = parser.parse_args()

# The static file is loaded first, so an entry in dbusers.json with the same
# username replaces the static one.
load_userfile('/etc/mysql/static-dbusers.json')
load_userfile('/etc/mysql/dbusers.json')

# Users on the server that appear in neither file are only reported at the
# end of the script; they are never dropped here.
to_delete = existing - should_exist
to_create = should_exist - existing

for createuser in to_create:
    dbvalue = by_user[createuser]

    username = dbvalue['username']
    database = dbvalue.get('database', None)
    privilege = dbvalue.get('privilege', 'ALL')

    if 'password' in dbvalue:
        password = dbvalue['password']
    else:
        # b64encode() returns bytes on Python 3. Decode to text so that the
        # password is not rendered as "b'...'" inside the SQL statement or
        # handed to cfn-signal as a bytes object.
        password = base64.b64encode(os.urandom(30)).decode('ascii')

    # Without a 'database' entry the privilege is granted on every database.
    if database is not None:
        cmd = "GRANT %s ON `%s`.*" % (privilege, database)
    else:
        cmd = "GRANT %s ON *.*" % (privilege)

    # NOTE(review): username, database, privilege and password are
    # interpolated into the SQL text unescaped, so a quote or backtick in any
    # of them would break (or alter) the statement. Presumably the JSON files
    # are trusted, root-managed configuration -- confirm.
    cmd_local = ("%s TO `%s`@'localhost' IDENTIFIED BY '%s'" %
                 (cmd, username, password))
    cmd_global = ("%s TO `%s`@'%%' IDENTIFIED BY '%s'" %
                  (cmd, username, password))

    if opts.noop:
        print(cmd_local)
        print(cmd_global)
    else:
        cursor = conn.cursor()
        try:
            cursor.execute(cmd_local)
            cursor.execute(cmd_global)
            cursor.execute("FLUSH PRIVILEGES")
        finally:
            # Release the cursor even when one of the statements fails.
            cursor.close()

    if 'userhandle' in dbvalue:
        # Inform Heat of new password for this user
        cmd = ['/opt/aws/bin/cfn-signal', '-i', username,
               '-s', 'true', '--data', password, dbvalue['userhandle']]
        if opts.noop:
            print(cmd)
        else:
            subprocess.check_call(cmd)

if to_delete:
    # Logger.warn is a deprecated alias of Logger.warning; pass the set as a
    # logging argument so formatting is left to the logging module.
    logger.warning('The following users are not accounted for: %s', to_delete)
