#    Copyright 2014 Mirantis, Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import json
import random
from unittest import TestCase

from lxml import etree
import mock

from devops.driver.libvirt.libvirt_xml_builder import LibvirtXMLBuilder
from devops.tests import factories


class BaseTestXMLBuilder(TestCase):
    """Shared fixture and XML assertion helpers for the builder tests.

    ``setUp`` creates a ``LibvirtXMLBuilder`` around a mocked driver and
    exposes ``self.net`` / ``self.node`` mocks for subclasses to configure.
    """

    def setUp(self):
        # TODO(prmtl): make it fuzzy
        self.volume_path = "volume_path_mock"
        self.driver_mock = mock.Mock()
        self.xml_builder = LibvirtXMLBuilder(self.driver_mock)

        # Stub the driver attributes with fixed values.
        driver = self.xml_builder.driver
        driver.volume_path = mock.Mock(return_value=self.volume_path)
        driver.network_name = mock.Mock(return_value="network_name_mock")
        driver.reboot_timeout = None
        driver.use_hugepages = None
        driver.enable_acpi = None

        # Model stand-ins; subclasses fill in the attributes they need.
        self.net = mock.Mock()
        self.node = mock.Mock()

    def _reformat_xml(self, xml):
        """Parse an XML string and serialise it back, pretty printed."""
        parsed = etree.fromstring(xml)
        return etree.tostring(parsed, pretty_print=True)

    def assertXMLEqual(self, first, second):
        """Assert that two XML documents are equal.

        Both documents are parsed and re-serialised before the comparison
        so that whitespace differences cause fewer false failures.
        """
        left, right = [self._reformat_xml(doc) for doc in (first, second)]
        # NOTE(prmtl): a bare assert gives better reporting (diff) in py.test
        assert left == right

    def assertXMLIn(self, member, container):
        """Assert, naively, that one XML document is contained in another.

        Both documents are re-serialised and a plain substring check is
        done. On failure both pretty printed documents are reported.
        """
        needle = self._reformat_xml(member)
        haystack = self._reformat_xml(container)
        if needle in haystack:
            return

        self.fail("\n{0}\n\nnot found in\n\n{1}".format(needle, haystack))

    def assertXMLNotIn(self, member, container):
        """Assert, naively, that one XML document is absent from another.

        Both documents are re-serialised and a plain substring check is
        done. On failure both pretty printed documents are reported.
        """
        needle = self._reformat_xml(member)
        haystack = self._reformat_xml(container)
        if needle not in haystack:
            return

        self.fail(
            "\n{0}\n\nunexpectedly found in\n\n{1}".format(needle, haystack))

    def assertXpath(self, xpath, xml):
        """Assert that the XPath expression gives a truthy result on XML."""
        root = etree.fromstring(xml)
        if root.xpath(xpath):
            return

        self.fail('No result for XPath on element\n'
                  'XPath: {xpath}\n'
                  'Element:\n'
                  '{xml}'.format(
                      xpath=xpath,
                      xml=etree.tostring(root, pretty_print=True)))


class TestNetworkXml(BaseTestXMLBuilder):
    """Tests for the XML returned by ``build_network_xml``."""

    def setUp(self):
        super(TestNetworkXml, self).setUp()
        # Baseline network: named, no forwarding, no IP network, no DHCP.
        baseline = (
            ('name', 'test_name'),
            ('forward', None),
            ('ip_network', None),
            ('has_dhcp_server', False),
        )
        for attr, value in baseline:
            setattr(self.net, attr, value)
        self.net.environment.name = 'test_env_name'

    def test_net_name_bridge_name(self):
        """Network name and bridge element appear in the output."""
        get_device_name = self.driver_mock.get_available_device_name
        get_device_name.return_value = 'fuelbr0'

        xml = self.xml_builder.build_network_xml(self.net)

        name_element = '<name>{0}_{1}</name>'.format(
            self.net.environment.name, self.net.name)
        self.assertXMLIn(name_element, xml)
        self.assertXMLIn('<bridge delay="0" name="fuelbr0" stp="on" />', xml)
        get_device_name.assert_called_once_with('fuelbr')

    def test_forward(self):
        """The forward mode set on the network appears in the output."""
        self.net.forward = "nat"

        xml = self.xml_builder.build_network_xml(self.net)

        forward_element = '<forward mode="{0}" />'.format(self.net.forward)
        self.assertXMLIn(forward_element, xml)

    def test_ip_network(self):
        """Address and prefix of the IP network appear in an <ip> element."""
        address, prefix_len = '172.0.1.1', '24'
        self.net.ip_network = "{0}/{1}".format(address, prefix_len)
        self.net.has_pxe_server = False
        self.net.tftp_root_dir = '/tmp'

        xml = self.xml_builder.build_network_xml(self.net)

        ip_element = '<ip address="{0}" prefix="{1}" />'.format(
            address, prefix_len)
        self.assertXMLIn(ip_element, xml)


class TestVolumeXml(BaseTestXMLBuilder):
    """Tests for the XML returned by ``build_volume_xml``."""

    def setUp(self):
        super(TestVolumeXml, self).setUp()

    def get_xml(self, volume):
        """Return the XML the builder generates for ``volume``."""
        builder = self.xml_builder
        return builder.build_volume_xml(volume)

    def test_full_volume_xml(self):
        """A factory-default volume yields the complete expected document."""
        volume = factories.VolumeFactory()
        template = '''<?xml version="1.0" encoding="utf-8" ?>
<volume>
    <name>{env_name}_{name}</name>
    <capacity>{capacity}</capacity>
    <target>
        <format type="{format}" />
        <permissions>
            <mode>0644</mode>
        </permissions>
    </target>
    <backingStore>
        <path>{path}</path>
        <format type="{store_format}" />
    </backingStore>
</volume>'''
        fields = dict(
            env_name=volume.environment.name,
            name=volume.name,
            capacity=volume.capacity,
            format=volume.format,
            path=self.volume_path,
            store_format=volume.backing_store.format,
        )
        expected = template.format(**fields)

        self.assertXMLEqual(expected, self.get_xml(volume))
        # The backing store path must have been resolved through the driver.
        volume_path_mock = self.xml_builder.driver.volume_path
        volume_path_mock.assert_called_with(volume.backing_store)

    def test_name_without_env(self):
        """Without an environment the name element holds the bare name."""
        volume = factories.VolumeFactory(environment=None)
        name_element = '<name>{0}</name>'.format(volume.name)
        self.assertXMLIn(name_element, self.get_xml(volume))

    def test_no_backing_store(self):
        """No <backingStore> element is present without a backing store."""
        volume = factories.VolumeFactory(backing_store=None)
        self.assertXpath("not(//backingStore)", self.get_xml(volume))

    def test_backing_store(self):
        """Backing store path and format appear in <backingStore>."""
        store_format = "raw"
        volume = factories.VolumeFactory(backing_store__format=store_format)
        xml = self.get_xml(volume)
        expected = '''
    <backingStore>
        <path>{path}</path>
        <format type="{format}" />
    </backingStore>'''.format(path=self.volume_path, format=store_format)
        self.assertXMLIn(expected, xml)


class TestSnapshotXml(BaseTestXMLBuilder):
    """Tests for the XML returned by ``build_snapshot_xml``."""

    def setUp(self):
        super(TestSnapshotXml, self).setUp()
        node = self.node
        node.name = factories.fuzzy_string('testname_')
        node.environment.name = factories.fuzzy_string('testenv_')

        # The node gets one disk device whose volume reports a fuzzy path.
        self.disk1_volume_path = factories.fuzzy_string('/volumes/')
        disk = mock.Mock()
        disk.device = 'disk'
        disk.target_dev = factories.fuzzy_string()
        disk.volume.get_path.return_value = self.disk1_volume_path
        self.disk1 = disk
        node.disk_devices = [disk]

    def domain_set_active(self, active=True):
        """Make the driver connection return a domain with given activity.

        The mocked domain is stored on ``self.domain`` and is what
        ``conn.lookupByUUIDString`` returns; its ``isActive()`` yields
        ``active``.
        """
        domain = mock.Mock()
        domain.isActive.return_value = active
        self.domain = domain
        self.xml_builder.driver.conn.lookupByUUIDString.return_value = domain

    def check_snaphot_xml(self, name, description, expected, disk_only=False,
                          external=False, external_dir=None):
        """Build the snapshot XML for ``self.node`` and look for ``expected``.

        ``name`` and ``description`` are passed positionally to the builder;
        the remaining options are forwarded as keyword arguments.
        """
        options = {
            'node': self.node,
            'disk_only': disk_only,
            'external': external,
            'external_dir': external_dir,
        }
        built = self.xml_builder.build_snapshot_xml(name, description,
                                                    **options)
        self.assertXMLIn(expected, built)

    def test_no_name(self):
        """Only the description is emitted when no name is given."""
        snap_name = None
        snap_description = factories.fuzzy_string('test_description_')
        template = '''
<domainsnapshot>
    <description>{0}</description>
</domainsnapshot>'''
        self.check_snaphot_xml(snap_name, snap_description,
                               template.format(snap_description))

    def test_external_domain_not_active(self):
        """External snapshot of an inactive domain has no memory snapshot."""
        self.domain_set_active(False)
        snap_name = factories.fuzzy_string('snapshot_')
        snap_dir = factories.fuzzy_string('/extsnap/')
        snap_description = None
        template = '''
<domainsnapshot>
    <name>{0}</name>
    <memory snapshot="no"/>
    <disks>
        <disk name="{2}" snapshot="external">
            <source file="{1}"/>
        </disk>
    </disks>
</domainsnapshot>'''
        expected = template.format(snap_name, self.disk1_volume_path,
                                   self.disk1.target_dev)
        self.check_snaphot_xml(snap_name, snap_description, expected,
                               disk_only=False, external=True,
                               external_dir=snap_dir)

    def test_external_disk_only(self):
        """Disk-only external snapshot has no memory snapshot."""
        self.domain_set_active(True)
        snap_name = factories.fuzzy_string('snapshot_')
        snap_dir = factories.fuzzy_string('/extsnap/')
        snap_description = None
        template = '''
<domainsnapshot>
    <name>{0}</name>
    <memory snapshot="no"/>
    <disks>
        <disk name="{2}" snapshot="external">
            <source file="{1}"/>
        </disk>
    </disks>
</domainsnapshot>'''
        expected = template.format(snap_name, self.disk1_volume_path,
                                   self.disk1.target_dev)
        self.check_snaphot_xml(snap_name, snap_description, expected,
                               disk_only=True, external=True,
                               external_dir=snap_dir)

    def test_external_snapshot(self):
        """External snapshot of an active domain stores memory in a file."""
        self.domain_set_active(True)
        snap_name = factories.fuzzy_string('snapshot_')
        snap_dir = factories.fuzzy_string('/extsnap/')
        snap_description = factories.fuzzy_string('test_description_')
        template = '''
<domainsnapshot>
    <name>{0}</name>
    <description>{1}</description>
    <memory file="{2}/snapshot-memory-{3}_{4}.{5}" snapshot="external"/>
    <disks>
        <disk name="{7}" snapshot="external">
            <source file="{6}"/>
        </disk>
    </disks>
</domainsnapshot>'''
        expected = template.format(
            snap_name, snap_description, snap_dir,
            self.node.environment.name, self.node.name, snap_name,
            self.disk1_volume_path, self.disk1.target_dev)
        self.check_snaphot_xml(snap_name, snap_description, expected,
                               disk_only=False, external=True,
                               external_dir=snap_dir)

    @mock.patch('devops.driver.libvirt.libvirt_xml_builder.os')
    def test_external_snapshot_memory_snapshot_exists(self, mock_os):
        """An existing memory file makes the builder use a "-0" suffix."""
        self.domain_set_active(True)
        snap_name = factories.fuzzy_string('snapshot_')
        snap_dir = factories.fuzzy_string('/extsnap/')
        snap_description = factories.fuzzy_string('test_description_')
        taken_filename = "{0}/snapshot-memory-{1}_{2}.{3}".format(
            snap_dir, self.node.environment.name,
            self.node.name, snap_name)

        def only_taken_file_exists(*args):
            # Report just the un-suffixed memory file as present.
            return args[0] == taken_filename
        mock_os.path.exists = only_taken_file_exists

        template = '''
<domainsnapshot>
    <name>{0}</name>
    <description>{1}</description>
    <memory file="{2}/snapshot-memory-{3}_{4}.{5}-0" snapshot="external"/>
    <disks>
        <disk name="{7}" snapshot="external">
            <source file="{6}"/>
        </disk>
    </disks>
</domainsnapshot>'''
        expected = template.format(
            snap_name, snap_description, snap_dir,
            self.node.environment.name, self.node.name, snap_name,
            self.disk1_volume_path, self.disk1.target_dev)
        self.check_snaphot_xml(snap_name, snap_description, expected,
                               disk_only=False, external=True,
                               external_dir=snap_dir)

    def test_no_description(self):
        """Only the name is emitted when no description is given."""
        snap_name = factories.fuzzy_string('test_snapshot_')
        snap_description = None
        template = '''
<domainsnapshot>
    <name>{0}</name>
</domainsnapshot>'''
        self.check_snaphot_xml(snap_name, snap_description,
                               template.format(snap_name))

    def test_nothing_there(self):
        """Neither name nor description yields an empty root element."""
        snap_name = snap_description = None
        self.check_snaphot_xml(snap_name, snap_description,
                               '<domainsnapshot />')

    def test_snapshot(self):
        """Both name and description are emitted when both are given."""
        snap_name = factories.fuzzy_string('test_snapshot_')
        snap_description = factories.fuzzy_string('test_description_')
        template = '''
<domainsnapshot>
    <name>{0}</name>
    <description>{1}</description>
</domainsnapshot>'''
        self.check_snaphot_xml(snap_name, snap_description,
                               template.format(snap_name, snap_description))


class TestNodeXml(BaseTestXMLBuilder):
    """Tests for the XML returned by ``build_node_xml``."""

    def setUp(self):
        super(TestNodeXml, self).setUp()

        node = self.node
        node.hypervisor = 'test_hypervisor'
        node.name = 'test_name'
        node.environment.name = 'test_env_name'
        node.vcpu = random.randint(1, 10)
        node.memory = random.randint(128, 1024)
        node.os_type = 'test_os_type'
        node.architecture = 'test_architecture'
        # Boot order is stored as a JSON-encoded list.
        node.boot = '["dev1", "dev2"]'
        node.has_vnc = None
        node.should_enable_boot_menu = False

        # By default the node has no interfaces and filtering its disk
        # devices yields nothing.
        no_disks = mock.MagicMock()
        no_disks.filter.return_value = []
        node.disk_devices = no_disks
        node.interfaces = []

    def test_node(self):
        """A node without disks, interfaces or NUMA yields the base domain."""
        xml = self.xml_builder.build_node_xml(self.node, 'test_emulator', [])
        first_boot, second_boot = json.loads(self.node.boot)[:2]
        # The expected document reports node.memory * 1024 with unit KiB.
        memory_kib = str(self.node.memory * 1024)
        template = '''
<domain type="test_hypervisor">
    <name>test_env_name_test_name</name>
    <cpu mode="host-passthrough" />
    <vcpu>{0}</vcpu>
    <memory unit="KiB">{1}</memory>
    <clock offset="utc" />
    <clock>
        <timer name="rtc" tickpolicy="catchup" track="wall">
            <catchup limit="10000" slew="120" threshold="123" />
        </timer>
    </clock>
    <clock>
        <timer name="pit" tickpolicy="delay" />
    </clock>
    <clock>
        <timer name="hpet" present="yes" />
    </clock>
    <os>
        <type arch="{2}">{3}</type>
        <boot dev="{4}" />
        <boot dev="{5}" />
    </os>
    <devices>
        <controller model="nec-xhci" type="usb" />
        <emulator>test_emulator</emulator>
        <video>
            <model heads="1" type="vga" vram="9216" />
        </video>
        <serial type="pty">
            <target port="0" />
        </serial>
        <console type="pty">
            <target port="0" type="serial" />
        </console>
    </devices>
</domain>'''
        expected = template.format(
            self.node.vcpu, memory_kib,
            self.node.architecture, self.node.os_type,
            first_boot, second_boot)
        self.assertXMLIn(expected, xml)

    def test_node_with_numa(self):
        """NUMA cells passed to the builder appear under the <cpu> element."""
        self.node.vcpu = 4
        self.node.memory = 1024
        cell_memory = 512 * 1024
        numa = [
            {'cpus': '0,1', 'memory': cell_memory},
            {'cpus': '2,3', 'memory': cell_memory},
        ]
        xml = self.xml_builder.build_node_xml(self.node, 'test_emulator', numa)
        first_boot, second_boot = json.loads(self.node.boot)[:2]
        # The expected document reports node.memory * 1024 with unit KiB.
        memory_kib = str(self.node.memory * 1024)
        template = '''
<domain type="test_hypervisor">
    <name>test_env_name_test_name</name>
    <cpu mode="host-passthrough">
        <numa>
            <cell cpus="0,1" memory="524288"/>
            <cell cpus="2,3" memory="524288"/>
        </numa>
    </cpu>
    <vcpu>{0}</vcpu>
    <memory unit="KiB">{1}</memory>
    <clock offset="utc" />
    <clock>
        <timer name="rtc" tickpolicy="catchup" track="wall">
            <catchup limit="10000" slew="120" threshold="123" />
        </timer>
    </clock>
    <clock>
        <timer name="pit" tickpolicy="delay" />
    </clock>
    <clock>
        <timer name="hpet" present="yes" />
    </clock>
    <os>
        <type arch="{2}">{3}</type>
        <boot dev="{4}" />
        <boot dev="{5}" />
    </os>
    <devices>
        <controller model="nec-xhci" type="usb" />
        <emulator>test_emulator</emulator>
        <video>
            <model heads="1" type="vga" vram="9216" />
        </video>
        <serial type="pty">
            <target port="0" />
        </serial>
        <console type="pty">
            <target port="0" type="serial" />
        </console>
    </devices>
</domain>'''
        expected = template.format(
            self.node.vcpu, memory_kib,
            self.node.architecture, self.node.os_type,
            first_boot, second_boot)
        self.assertXMLIn(expected, xml)

    @mock.patch('devops.driver.libvirt.libvirt_xml_builder.uuid')
    def test_node_devices(self, mock_uuid):
        """Each disk device is rendered as a <disk> with a serial."""
        # Fix the value the builder obtains from uuid4().hex.
        mock_uuid.uuid4.return_value.hex = 'disk-serial'

        disk_devices = []
        for idx in range(3):
            volume = mock.Mock(uuid=idx, format='frmt{0}'.format(idx))
            disk = mock.Mock(
                type='type{0}'.format(idx),
                device='device{0}'.format(idx),
                volume=volume,
                target_dev='tdev{0}'.format(idx),
                bus='bus{0}'.format(idx)
            )
            disk_devices.append(disk)
        self.node.disk_devices = disk_devices

        xml = self.xml_builder.build_node_xml(self.node, 'test_emulator', [])
        expected = '''
    <devices>
        <controller model="nec-xhci" type="usb" />
        <emulator>test_emulator</emulator>
        <disk device="device0" type="type0">
            <driver cache="unsafe" type="frmt0" />
            <source file="volume_path_mock" />
            <target bus="bus0" dev="tdev0" />
            <serial>disk-serial</serial>
        </disk>
        <disk device="device1" type="type1">
            <driver cache="unsafe" type="frmt1" />
            <source file="volume_path_mock" />
            <target bus="bus1" dev="tdev1" />
            <serial>disk-serial</serial>
        </disk>
        <disk device="device2" type="type2">
            <driver cache="unsafe" type="frmt2" />
            <source file="volume_path_mock" />
            <target bus="bus2" dev="tdev2" />
            <serial>disk-serial</serial>
        </disk>
        <video>
            <model heads="1" type="vga" vram="9216" />
        </video>
        <serial type="pty">
            <target port="0" />
        </serial>
        <console type="pty">
            <target port="0" type="serial" />
        </console>
    </devices>'''
        self.assertXMLIn(expected, xml)

    @mock.patch('devops.driver.libvirt.libvirt_xml_builder.uuid')
    def test_node_multipath_devices(self, mock_uuid):
        """Disk devices listed twice are rendered twice, each with a <wwn>."""
        # Fix the value the builder obtains from uuid4().hex.
        mock_uuid.uuid4.return_value.hex = 'disk-serial'

        disk_devices = []
        for idx in range(3):
            volume = mock.Mock(uuid=idx, format='frmt{0}'.format(idx))
            disk = mock.Mock(
                type='type{0}'.format(idx),
                device='device{0}'.format(idx),
                volume=volume,
                target_dev='tdev{0}'.format(idx),
                bus='bus{0}'.format(idx)
            )
            disk_devices.append(disk)

        # The same three devices are attached to the node two times over.
        self.node.disk_devices = disk_devices * 2

        xml = self.xml_builder.build_node_xml(self.node, 'test_emulator', [])
        expected = '''
    <devices>
        <controller model="nec-xhci" type="usb" />
        <emulator>test_emulator</emulator>
        <disk device="device0" type="type0">
            <driver cache="unsafe" type="frmt0"/>
            <source file="volume_path_mock"/>
            <target bus="bus0" dev="tdev0"/>
            <serial>disk-serial</serial>
            <wwn>0disk-serial</wwn>
        </disk>
        <disk device="device1" type="type1">
            <driver cache="unsafe" type="frmt1"/>
            <source file="volume_path_mock"/>
            <target bus="bus1" dev="tdev1"/>
            <serial>disk-serial</serial>
            <wwn>0disk-serial</wwn>
        </disk>
        <disk device="device2" type="type2">
            <driver cache="unsafe" type="frmt2"/>
            <source file="volume_path_mock"/>
            <target bus="bus2" dev="tdev2"/>
            <serial>disk-serial</serial>
            <wwn>0disk-serial</wwn>
        </disk>
        <disk device="device0" type="type0">
            <driver cache="unsafe" type="frmt0"/>
            <source file="volume_path_mock"/>
            <target bus="bus0" dev="tdev0"/>
            <serial>disk-serial</serial>
            <wwn>0disk-serial</wwn>
        </disk>
        <disk device="device1" type="type1">
            <driver cache="unsafe" type="frmt1"/>
            <source file="volume_path_mock"/>
            <target bus="bus1" dev="tdev1"/>
            <serial>disk-serial</serial>
            <wwn>0disk-serial</wwn>
        </disk>
        <disk device="device2" type="type2">
            <driver cache="unsafe" type="frmt2"/>
            <source file="volume_path_mock"/>
            <target bus="bus2" dev="tdev2"/>
            <serial>disk-serial</serial>
            <wwn>0disk-serial</wwn>
        </disk>
        <video>
            <model heads="1" type="vga" vram="9216" />
        </video>
        <serial type="pty">
            <target port="0" />
        </serial>
        <console type="pty">
            <target port="0" type="serial" />
        </console>
    </devices>'''
        self.assertXMLIn(expected, xml)

    def test_node_interfaces(self):
        """Each interface is rendered with mac, source, model and filter."""
        interfaces = []
        for idx in range(3):
            network = mock.Mock(uuid=idx, environment=self.node.environment)
            # "name" is special in the Mock constructor, so it is set after
            # creation through configure_mock.
            network.configure_mock(name='network_name_mock_{0}'.format(idx))
            interfaces.append(
                mock.Mock(type='network',
                          mac_address='mac{0}'.format(idx),
                          network=network,
                          model='model{0}'.format(idx)))
        self.node.interfaces = interfaces

        xml = self.xml_builder.build_node_xml(self.node, 'test_emulator', [])
        expected = '''
    <devices>
        <controller model="nec-xhci" type="usb" />
        <emulator>test_emulator</emulator>
        <interface type="network">
            <mac address="mac0" />
            <source network="network_name_mock" />
            <model type="model0" />
            <filterref filter="test_env_name_network_name_mock_0_mac0"/>
        </interface>
        <interface type="network">
            <mac address="mac1" />
            <source network="network_name_mock" />
            <model type="model1" />
            <filterref filter="test_env_name_network_name_mock_1_mac1"/>
        </interface>
        <interface type="network">
            <mac address="mac2" />
            <source network="network_name_mock" />
            <model type="model2" />
            <filterref filter="test_env_name_network_name_mock_2_mac2"/>
        </interface>
        <video>
            <model heads="1" type="vga" vram="9216" />
        </video>
        <serial type="pty">
            <target port="0" />
        </serial>
        <console type="pty">
            <target port="0" type="serial" />
        </console>
    </devices>'''
        self.assertXMLIn(expected, xml)
