# Copyright 2015 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
from unittest import mock

from cryptography import fernet
from oslo_config import cfg
from oslo_config import fixture as oslo_fixture
from oslo_utils import uuidutils
from taskflow.types import failure

from octavia.amphorae.driver_exceptions import exceptions as driver_except
from octavia.common import constants
from octavia.common import data_models
from octavia.common import utils
from octavia.controller.worker.v1.tasks import amphora_driver_tasks
from octavia.db import repositories as repo
import octavia.tests.unit.base as base


# Identifiers shared by every fixture and test in this module. They are
# generated once, at import time, before any test-level patching applies.
AMP_ID = uuidutils.generate_uuid()
COMPUTE_ID = uuidutils.generate_uuid()
LISTENER_ID = uuidutils.generate_uuid()
LB_ID = uuidutils.generate_uuid()

# Retry settings used for config overrides and timeout dictionaries.
CONN_MAX_RETRIES = 10
CONN_RETRY_INTERVAL = 6
FAKE_CONFIG_FILE = 'fake config file'

# Module-level mocks. They are shared by all tests and some tests mutate
# their attributes, so treat them as shared state.
_amphora_mock = mock.MagicMock(id=AMP_ID,
                               status=constants.AMPHORA_ALLOCATED,
                               vrrp_ip='198.51.100.65')
_amphorae_mock = [_amphora_mock]

_listener_mock = mock.MagicMock(id=LISTENER_ID)
_vip_mock = mock.MagicMock()
_load_balancer_mock = mock.MagicMock(id=LB_ID,
                                     listeners=[_listener_mock],
                                     vip=_vip_mock)
_LB_mock = mock.MagicMock()

# Network configuration of a single amphora, and the per-amphora mapping
# (keyed by amphora id) that wraps it.
_amphora_network_config_mock = mock.MagicMock()
_amphorae_network_config_mock = {
    _amphora_mock.id: _amphora_network_config_mock}

_network_mock = mock.MagicMock()
_port_mock = mock.MagicMock()
_ports_mock = [_port_mock]
_session_mock = mock.MagicMock()


@mock.patch('octavia.db.repositories.AmphoraRepository.update')
@mock.patch('octavia.db.repositories.ListenerRepository.update')
@mock.patch('octavia.db.repositories.ListenerRepository.get',
            return_value=_listener_mock)
@mock.patch('octavia.db.api.get_session', return_value=_session_mock)
@mock.patch('octavia.controller.worker.v1.tasks.amphora_driver_tasks.LOG')
@mock.patch('oslo_utils.uuidutils.generate_uuid', return_value=AMP_ID)
@mock.patch('stevedore.driver.DriverManager.driver')
class TestAmphoraDriverTasks(base.TestCase):

    def setUp(self):
        """Reset the shared LB mock and apply per-test config overrides.

        The oslo.config fixture is registered through ``useFixture`` (after
        the base ``setUp``) so that its cleanup clears these overrides when
        the test ends. A fixture that is only instantiated never runs that
        cleanup, which leaves the overrides in the process-wide ``cfg.CONF``
        for every later test.
        """
        super().setUp()

        # Point the shared LB mock back at the single shared amphora; some
        # tests reassign ``amphorae`` on this module-level object.
        _LB_mock.amphorae = [_amphora_mock]
        _LB_mock.id = LB_ID

        conf = self.useFixture(oslo_fixture.Config(cfg.CONF))
        # NOTE(review): "rety" is kept exactly as in the original override;
        # presumably it matches the registered option name -- confirm.
        conf.config(group="haproxy_amphora",
                    active_connection_max_retries=CONN_MAX_RETRIES,
                    active_connection_rety_interval=CONN_RETRY_INTERVAL)
        conf.config(group="controller_worker",
                    loadbalancer_topology=constants.TOPOLOGY_SINGLE)

        # Timeouts handed to tasks and expected verbatim in driver calls.
        self.timeout_dict = {constants.REQ_CONN_TIMEOUT: 1,
                             constants.REQ_READ_TIMEOUT: 2,
                             constants.CONN_MAX_RETRIES: 3,
                             constants.CONN_RETRY_INTERVAL: 4}

    @mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
    def test_amp_listeners_update(self,
                                  mock_lb_repo_get,
                                  mock_driver,
                                  mock_generate_uuid,
                                  mock_log,
                                  mock_get_session,
                                  mock_listener_repo_get,
                                  mock_listener_repo_update,
                                  mock_amphora_repo_update):
        """AmpListenersUpdate pushes listeners; a failure flags ERROR."""
        mock_lb_repo_get.return_value = _LB_mock
        task = amphora_driver_tasks.AmpListenersUpdate()
        update_listeners = mock_driver.update_amphora_listeners

        # Success: the driver receives the LB returned by the repository.
        task.execute(_load_balancer_mock, _amphora_mock, self.timeout_dict)
        update_listeners.assert_called_once_with(
            _LB_mock, _amphora_mock, self.timeout_dict)

        # Failure: a driver error marks the amphora as ERROR.
        update_listeners.side_effect = Exception('boom')
        task.execute(_load_balancer_mock, _amphora_mock, self.timeout_dict)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, AMP_ID, status=constants.ERROR)

    @mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
    def test_amphorae_listeners_update(self,
                                       mock_lb_repo_get,
                                       mock_driver,
                                       mock_generate_uuid,
                                       mock_log,
                                       mock_get_session,
                                       mock_listener_repo_get,
                                       mock_listener_repo_update,
                                       mock_amphora_repo_update):
        """AmphoraIndexListenerUpdate honours reachability and amp id."""
        mock_lb_repo_get.return_value = _LB_mock
        task = amphora_driver_tasks.AmphoraIndexListenerUpdate()

        def status_for(unreachable):
            # Build the per-amphora status mapping the task consults.
            return {_amphora_mock.id: {constants.UNREACHABLE: unreachable}}

        # Reachable amphora: listeners are pushed through the driver.
        task.execute(_load_balancer_mock, 0, [_amphora_mock],
                     status_for(False), _amphora_mock.id, self.timeout_dict)
        mock_driver.update_amphora_listeners.assert_called_once_with(
            _LB_mock, _amphora_mock, self.timeout_dict)

        # Unreachable amphora: the driver is not called at all.
        mock_driver.reset_mock()
        task.execute(_LB_mock, 0, [_amphora_mock], status_for(True),
                     _amphora_mock.id, self.timeout_dict)
        mock_driver.update_amphora_listeners.assert_not_called()

        # Driver failure on the amphora matching new_amphora_id: ERROR.
        mock_driver.update_amphora_listeners.side_effect = Exception('boom')
        task.execute(_load_balancer_mock, 0, [_amphora_mock], {},
                     _amphora_mock.id, self.timeout_dict)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, AMP_ID, status=constants.ERROR)

        # Driver failure on any other amphora: status is left untouched.
        mock_amphora_repo_update.reset_mock()
        mock_driver.update_amphora_listeners.side_effect = Exception('boom')
        task.execute(_load_balancer_mock, 0, [_amphora_mock], {}, '1234',
                     self.timeout_dict)
        mock_amphora_repo_update.assert_not_called()

    def test_listener_update(self,
                             mock_driver,
                             mock_generate_uuid,
                             mock_log,
                             mock_get_session,
                             mock_listener_repo_get,
                             mock_listener_repo_update,
                             mock_amphora_repo_update):
        """ListenersUpdate calls the driver; revert marks listeners ERROR."""
        task = amphora_driver_tasks.ListenersUpdate()
        task.execute(_load_balancer_mock)
        mock_driver.update.assert_called_once_with(_load_balancer_mock)

        def assert_revert_marks_error():
            result = task.revert(_load_balancer_mock)
            mock_listener_repo_update.assert_called_once_with(
                _session_mock,
                id=LISTENER_ID,
                provisioning_status=constants.ERROR)
            self.assertIsNone(result)

        # Revert with a working repository.
        assert_revert_marks_error()

        # Revert when the repository update raises: still attempted once,
        # and revert still returns None.
        mock_listener_repo_update.reset_mock()
        mock_listener_repo_update.side_effect = Exception('fail')
        assert_revert_marks_error()

    def test_listeners_update(self,
                              mock_driver,
                              mock_generate_uuid,
                              mock_log,
                              mock_get_session,
                              mock_listener_repo_get,
                              mock_listener_repo_update,
                              mock_amphora_repo_update):
        """ListenersUpdate updates the LB; revert marks each listener ERROR."""
        listeners_update_obj = amphora_driver_tasks.ListenersUpdate()
        listeners = [data_models.Listener(id='listener1'),
                     data_models.Listener(id='listener2')]
        vip = data_models.Vip(ip_address='10.0.0.1')
        lb = data_models.LoadBalancer(id='lb1', listeners=listeners, vip=vip)
        listeners_update_obj.execute(lb)
        # A single driver call covers the whole load balancer.
        mock_driver.update.assert_called_once_with(lb)

        # Test the revert: one ERROR update per listener, in listener order.
        amp = listeners_update_obj.revert(lb)
        expected_db_calls = [mock.call(_session_mock,
                                       id=listener.id,
                                       provisioning_status=constants.ERROR)
                             for listener in listeners]
        # assert_has_calls, not has_calls: the latter is merely an
        # auto-created MagicMock attribute and verifies nothing.
        mock_listener_repo_update.assert_has_calls(expected_db_calls)
        self.assertEqual(2, mock_listener_repo_update.call_count)
        self.assertIsNone(amp)

    @mock.patch('octavia.controller.worker.task_utils.TaskUtils.'
                'mark_listener_prov_status_error')
    def test_amphora_index_listeners_reload(
            self, mock_prov_status_error, mock_driver, mock_generate_uuid,
            mock_log, mock_get_session, mock_listener_repo_get,
            mock_listener_repo_update, mock_amphora_repo_update):
        """Exercise AmphoraIndexListenersReload.

        Covers an LB without listeners, a successful reload, an unreachable
        amphora, and a driver failure both on the amphora whose id matches
        ``new_amphora_id`` (marked ERROR) and on a different one (left
        alone).
        """
        amphora_mock = mock.MagicMock()
        listeners_reload_obj = (
            amphora_driver_tasks.AmphoraIndexListenersReload())
        mock_lb = mock.MagicMock()
        mock_listener = mock.MagicMock()
        mock_listener.id = '12345'
        # One entry per driver.reload call made below: the successful
        # reload, then the two failure scenarios. Both failures are listed
        # explicitly so the last one does not silently rely on the
        # StopIteration an exhausted side_effect iterator raises.
        mock_driver.reload.side_effect = [mock.DEFAULT, Exception('boom'),
                                          Exception('boom')]

        # Test no listeners
        mock_lb.listeners = None
        listeners_reload_obj.execute(mock_lb, 0, None, {}, amphora_mock.id)
        mock_driver.reload.assert_not_called()

        # Test with listeners
        amphorae_status = {
            _amphora_mock.id: {
                constants.UNREACHABLE: False
            }
        }
        mock_driver.reload.reset_mock()
        mock_lb.listeners = [mock_listener]
        listeners_reload_obj.execute(mock_lb, 0, [amphora_mock],
                                     amphorae_status,
                                     amphora_mock.id,
                                     timeout_dict=self.timeout_dict)
        mock_driver.reload.assert_called_once_with(mock_lb, amphora_mock,
                                                   self.timeout_dict)

        # Unreachable amp
        amphorae_status = {
            _amphora_mock.id: {
                constants.UNREACHABLE: True
            }
        }
        mock_driver.reload.reset_mock()
        listeners_reload_obj.execute(mock_lb, 0, [_amphora_mock],
                                     amphorae_status,
                                     _amphora_mock.id,
                                     timeout_dict=self.timeout_dict)
        mock_driver.reload.assert_not_called()

        # Test with reload exception
        mock_driver.reload.reset_mock()
        listeners_reload_obj.execute(mock_lb, 0, [amphora_mock], {},
                                     amphora_mock.id,
                                     timeout_dict=self.timeout_dict)
        mock_driver.reload.assert_called_once_with(mock_lb, amphora_mock,
                                                   self.timeout_dict)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, amphora_mock.id, status=constants.ERROR)

        # Test with reload exception, secondary amp
        mock_driver.reload.reset_mock()
        mock_amphora_repo_update.reset_mock()
        listeners_reload_obj.execute(mock_lb, 0, [_amphora_mock], {},
                                     '1234',
                                     timeout_dict=self.timeout_dict)
        mock_driver.reload.assert_called_once_with(mock_lb, _amphora_mock,
                                                   self.timeout_dict)
        mock_amphora_repo_update.assert_not_called()

    @mock.patch('octavia.controller.worker.task_utils.TaskUtils.'
                'mark_listener_prov_status_error')
    def test_listeners_start(self,
                             mock_prov_status_error,
                             mock_driver,
                             mock_generate_uuid,
                             mock_log,
                             mock_get_session,
                             mock_listener_repo_get,
                             mock_listener_repo_update,
                             mock_amphora_repo_update):
        """ListenersStart starts listeners only when the LB has some."""
        task = amphora_driver_tasks.ListenersStart()
        lb = mock.MagicMock()
        listener = mock.MagicMock()
        listener.id = '12345'
        start = mock_driver.start

        # No listeners: nothing is started.
        lb.listeners = None
        task.execute(lb)
        start.assert_not_called()

        # One listener: the driver is called with the LB and None.
        start.reset_mock()
        lb.listeners = [listener]
        task.execute(lb)
        start.assert_called_once_with(lb, None)

        # Revert reports the listener through
        # mark_listener_prov_status_error.
        task.revert(lb)
        mock_prov_status_error.assert_called_once_with('12345')

    def test_listener_delete(self,
                             mock_driver,
                             mock_generate_uuid,
                             mock_log,
                             mock_get_session,
                             mock_listener_repo_get,
                             mock_listener_repo_update,
                             mock_amphora_repo_update):
        """ListenerDelete calls the driver; revert marks the listener ERROR."""
        task = amphora_driver_tasks.ListenerDelete()
        task.execute(_listener_mock)
        mock_driver.delete.assert_called_once_with(_listener_mock)

        def assert_revert_marks_error():
            result = task.revert(_listener_mock)
            mock_listener_repo_update.assert_called_once_with(
                _session_mock,
                id=LISTENER_ID,
                provisioning_status=constants.ERROR)
            self.assertIsNone(result)

        # Revert with a working repository.
        assert_revert_marks_error()

        # Revert when the repository update raises: still attempted once,
        # and revert still returns None.
        mock_listener_repo_update.reset_mock()
        mock_listener_repo_update.side_effect = Exception('fail')
        assert_revert_marks_error()

    def test_amphora_get_info(self,
                              mock_driver,
                              mock_generate_uuid,
                              mock_log,
                              mock_get_session,
                              mock_listener_repo_get,
                              mock_listener_repo_update,
                              mock_amphora_repo_update):
        """AmphoraGetInfo forwards the amphora to the driver's get_info."""
        task = amphora_driver_tasks.AmphoraGetInfo()
        task.execute(_amphora_mock)
        mock_driver.get_info.assert_called_once_with(_amphora_mock)

    def test_amphora_get_diagnostics(self,
                                     mock_driver,
                                     mock_generate_uuid,
                                     mock_log,
                                     mock_get_session,
                                     mock_listener_repo_get,
                                     mock_listener_repo_update,
                                     mock_amphora_repo_update):
        """AmphoraGetDiagnostics forwards the amphora to the driver."""
        task = amphora_driver_tasks.AmphoraGetDiagnostics()
        task.execute(_amphora_mock)
        mock_driver.get_diagnostics.assert_called_once_with(_amphora_mock)

    def test_amphora_finalize(self,
                              mock_driver,
                              mock_generate_uuid,
                              mock_log,
                              mock_get_session,
                              mock_listener_repo_get,
                              mock_listener_repo_update,
                              mock_amphora_repo_update):
        """AmphoraFinalize finalizes via the driver; revert flags ERROR."""
        task = amphora_driver_tasks.AmphoraFinalize()
        task.execute(_amphora_mock)
        mock_driver.finalize_amphora.assert_called_once_with(_amphora_mock)

        def assert_revert_marks_error():
            result = task.revert(None, _amphora_mock)
            mock_amphora_repo_update.assert_called_once_with(
                _session_mock, id=AMP_ID, status=constants.ERROR)
            self.assertIsNone(result)

        # Revert with a working repository.
        assert_revert_marks_error()

        # Revert when the repository update raises: still attempted once.
        mock_amphora_repo_update.reset_mock()
        mock_amphora_repo_update.side_effect = Exception('fail')
        assert_revert_marks_error()

        # Revert after this task itself failed: the DB is left untouched.
        mock_amphora_repo_update.reset_mock()
        task.revert(failure.Failure.from_exception(Exception('boom')),
                    _amphora_mock)
        mock_amphora_repo_update.assert_not_called()

    def test_amphora_post_network_plug(self,
                                       mock_driver,
                                       mock_generate_uuid,
                                       mock_log,
                                       mock_get_session,
                                       mock_listener_repo_get,
                                       mock_listener_repo_update,
                                       mock_amphora_repo_update):
        """AmphoraPostNetworkPlug post-plugs ports; revert flags ERROR."""
        task = amphora_driver_tasks.AmphoraPostNetworkPlug()
        task.execute(_amphora_mock, _ports_mock, _amphora_network_config_mock)
        mock_driver.post_network_plug.assert_called_once_with(
            _amphora_mock, _port_mock, _amphora_network_config_mock)

        def assert_revert_marks_error():
            result = task.revert(None, _amphora_mock)
            mock_amphora_repo_update.assert_called_once_with(
                _session_mock, id=AMP_ID, status=constants.ERROR)
            self.assertIsNone(result)

        # Revert with a working repository.
        assert_revert_marks_error()

        # Revert when the repository update raises: still attempted once.
        mock_amphora_repo_update.reset_mock()
        mock_amphora_repo_update.side_effect = Exception('fail')
        assert_revert_marks_error()

        # Revert after this task itself failed: the DB is left untouched.
        mock_amphora_repo_update.reset_mock()
        task.revert(failure.Failure.from_exception(Exception('boom')),
                    _amphora_mock)
        mock_amphora_repo_update.assert_not_called()

    @mock.patch('octavia.db.repositories.AmphoraRepository.get_all')
    def test_amphorae_post_network_plug(self, mock_amp_get_all, mock_driver,
                                        mock_generate_uuid,
                                        mock_log,
                                        mock_get_session,
                                        mock_listener_repo_get,
                                        mock_listener_repo_update,
                                        mock_amphora_repo_update):
        """AmphoraePostNetworkPlug plugs only amphorae with port deltas."""
        mock_driver.get_network.return_value = _network_mock
        _amphora_mock.id = AMP_ID
        _amphora_mock.compute_id = COMPUTE_ID
        mock_amp_get_all.return_value = [[_amphora_mock], None]
        task = amphora_driver_tasks.AmphoraePostNetworkPlug()
        port = mock.Mock()

        # The amphora has a delta entry: its new port is post-plugged.
        deltas = {_amphora_mock.id: [port]}
        task.execute(_LB_mock, deltas, _amphorae_network_config_mock)
        mock_driver.post_network_plug.assert_called_once_with(
            _amphora_mock, port, _amphora_network_config_mock)

        # No delta keyed by a known amphora: nothing is plugged.
        mock_driver.post_network_plug.reset_mock()
        deltas = {'0': [port]}
        task.execute(_LB_mock, deltas, _amphora_network_config_mock)
        mock_driver.post_network_plug.assert_not_called()

        def assert_revert_marks_error():
            result = task.revert(None, _LB_mock, deltas)
            mock_amphora_repo_update.assert_called_once_with(
                _session_mock, id=AMP_ID, status=constants.ERROR)
            self.assertIsNone(result)

        # Revert with a working repository.
        assert_revert_marks_error()

        # Revert when the repository update raises: still attempted once.
        mock_amphora_repo_update.reset_mock()
        mock_amphora_repo_update.side_effect = Exception('fail')
        assert_revert_marks_error()

        # Revert after this task itself failed: the DB is left untouched.
        mock_amphora_repo_update.reset_mock()
        task.revert(failure.Failure.from_exception(Exception('boom')),
                    _amphora_mock, None)
        mock_amphora_repo_update.assert_not_called()

    @mock.patch('octavia.db.repositories.LoadBalancerRepository.update')
    def test_amphora_post_vip_plug(self,
                                   mock_loadbalancer_repo_update,
                                   mock_driver,
                                   mock_generate_uuid,
                                   mock_log,
                                   mock_get_session,
                                   mock_listener_repo_get,
                                   mock_listener_repo_update,
                                   mock_amphora_repo_update):
        """AmphoraPostVIPPlug plugs the VIP; revert flags only the amp."""
        net_config = mock.Mock()
        task = amphora_driver_tasks.AmphoraPostVIPPlug()
        task.execute(_amphora_mock, _LB_mock, net_config)
        mock_driver.post_vip_plug.assert_called_once_with(
            _amphora_mock, _LB_mock, net_config)

        def assert_revert_marks_amphora_only():
            result = task.revert(None, _amphora_mock, _LB_mock)
            mock_amphora_repo_update.assert_called_once_with(
                _session_mock, id=AMP_ID, status=constants.ERROR)
            mock_loadbalancer_repo_update.assert_not_called()
            self.assertIsNone(result)

        # Revert with working repositories.
        assert_revert_marks_amphora_only()

        # Revert when both repositories raise on update: same outcome.
        mock_amphora_repo_update.reset_mock()
        mock_loadbalancer_repo_update.reset_mock()
        mock_amphora_repo_update.side_effect = Exception('fail')
        mock_loadbalancer_repo_update.side_effect = Exception('fail')
        assert_revert_marks_amphora_only()

        # Revert after this task itself failed: the DB is left untouched.
        mock_amphora_repo_update.reset_mock()
        task.revert(failure.Failure.from_exception(Exception('boom')),
                    _amphora_mock, None)
        mock_amphora_repo_update.assert_not_called()

    @mock.patch('octavia.db.repositories.LoadBalancerRepository.update')
    def test_amphorae_post_vip_plug(self,
                                    mock_loadbalancer_repo_update,
                                    mock_driver,
                                    mock_generate_uuid,
                                    mock_log,
                                    mock_get_session,
                                    mock_listener_repo_get,
                                    mock_listener_repo_update,
                                    mock_amphora_repo_update):
        """AmphoraePostVIPPlug post-plugs the VIP on the LB's amphora."""
        net_config = mock.Mock()
        task = amphora_driver_tasks.AmphoraePostVIPPlug()
        task.execute(_LB_mock, net_config)
        mock_driver.post_vip_plug.assert_called_once_with(
            _amphora_mock, _LB_mock, net_config)

    def test_amphora_cert_upload(self,
                                 mock_driver,
                                 mock_generate_uuid,
                                 mock_log,
                                 mock_get_session,
                                 mock_listener_repo_get,
                                 mock_listener_repo_update,
                                 mock_amphora_repo_update):
        """AmphoraCertUpload hands the decrypted PEM to the driver."""
        cipher = fernet.Fernet(
            utils.get_compatible_server_certs_key_passphrase())
        encrypted_pem = cipher.encrypt(
            utils.get_compatible_value('test-pem-file'))

        task = amphora_driver_tasks.AmphoraCertUpload()
        task.execute(_amphora_mock, encrypted_pem)

        mock_driver.upload_cert_amp.assert_called_once_with(
            _amphora_mock, cipher.decrypt(encrypted_pem))

    def test_amphora_update_vrrp_interface(self,
                                           mock_driver,
                                           mock_generate_uuid,
                                           mock_log,
                                           mock_get_session,
                                           mock_listener_repo_get,
                                           mock_listener_repo_update,
                                           mock_amphora_repo_update):
        """AmphoraUpdateVRRPInterface stores the interface or flags ERROR."""
        fake_interface = 'fake0'
        _LB_mock.amphorae = _amphorae_mock
        get_iface = mock_driver.get_interface_from_ip
        # First lookup succeeds, the second raises.
        get_iface.side_effect = [fake_interface, Exception('boom')]
        timeout_dict = {
            constants.CONN_MAX_RETRIES: CONN_MAX_RETRIES,
            constants.CONN_RETRY_INTERVAL: CONN_RETRY_INTERVAL,
        }
        task = amphora_driver_tasks.AmphoraUpdateVRRPInterface()

        # Success: the discovered interface is persisted.
        task.execute(_amphora_mock, timeout_dict)
        get_iface.assert_called_once_with(
            _amphora_mock, _amphora_mock.vrrp_ip, timeout_dict=timeout_dict)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, _amphora_mock.id, vrrp_interface=fake_interface)

        # Failure: the amphora is marked as ERROR.
        mock_amphora_repo_update.reset_mock()
        task.execute(_amphora_mock, timeout_dict)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, _amphora_mock.id, status=constants.ERROR)

    def test_amphora_index_update_vrrp_interface(
            self, mock_driver, mock_generate_uuid, mock_log, mock_get_session,
            mock_listener_repo_get, mock_listener_repo_update,
            mock_amphora_repo_update):
        """Exercise AmphoraIndexUpdateVRRPInterface.

        Covers a successful lookup (interface persisted), an unreachable
        amphora (driver skipped), and a driver failure both on the amphora
        whose id matches ``new_amphora_id`` (marked ERROR) and on a
        different one (left alone).
        """
        FAKE_INTERFACE = 'fake0'
        _LB_mock.amphorae = _amphorae_mock
        # One entry per get_interface_from_ip call made below: the
        # successful lookup, then the two failure scenarios. Both failures
        # are listed explicitly so the last one does not silently rely on
        # the StopIteration an exhausted side_effect iterator raises.
        mock_driver.get_interface_from_ip.side_effect = [FAKE_INTERFACE,
                                                         Exception('boom'),
                                                         Exception('boom')]
        amphorae_status = {
            _amphora_mock.id: {
                constants.UNREACHABLE: False
            }
        }

        timeout_dict = {constants.CONN_MAX_RETRIES: CONN_MAX_RETRIES,
                        constants.CONN_RETRY_INTERVAL: CONN_RETRY_INTERVAL}

        amphora_update_vrrp_interface_obj = (
            amphora_driver_tasks.AmphoraIndexUpdateVRRPInterface())
        amphora_update_vrrp_interface_obj.execute(
            0, [_amphora_mock], amphorae_status, _amphora_mock.id,
            timeout_dict)
        mock_driver.get_interface_from_ip.assert_called_once_with(
            _amphora_mock, _amphora_mock.vrrp_ip, timeout_dict=timeout_dict)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, _amphora_mock.id, vrrp_interface=FAKE_INTERFACE)

        # Unreachable amp
        mock_driver.reset_mock()
        amphorae_status = {
            _amphora_mock.id: {
                constants.UNREACHABLE: True
            }
        }
        amphora_update_vrrp_interface_obj.execute(
            0, [_amphora_mock], amphorae_status, _amphora_mock.id,
            timeout_dict)
        mock_driver.get_interface_from_ip.assert_not_called()

        # Test with an exception
        mock_amphora_repo_update.reset_mock()
        amphora_update_vrrp_interface_obj.execute(
            0, [_amphora_mock], {}, _amphora_mock.id, timeout_dict)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, _amphora_mock.id, status=constants.ERROR)

        # Test with an exception, secondary amp
        mock_amphora_repo_update.reset_mock()
        amphora_update_vrrp_interface_obj.execute(
            0, [_amphora_mock], {}, '1234', timeout_dict)
        mock_amphora_repo_update.assert_not_called()

    @mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
    def test_amphora_vrrp_update(self,
                                 mock_lb_get,
                                 mock_driver,
                                 mock_generate_uuid,
                                 mock_log,
                                 mock_get_session,
                                 mock_listener_repo_get,
                                 mock_listener_repo_update,
                                 mock_amphora_repo_update):
        """AmphoraVRRPUpdate pushes VRRP config or flags the amp ERROR."""
        net_config = mock.MagicMock()
        mock_lb_get.return_value = _LB_mock
        update_vrrp = mock_driver.update_vrrp_conf
        # First update succeeds, the second raises.
        update_vrrp.side_effect = [mock.DEFAULT, Exception('boom')]
        task = amphora_driver_tasks.AmphoraVRRPUpdate()

        # Success: the driver receives the LB loaded from the repository.
        task.execute(_LB_mock.id, net_config, _amphora_mock, 'fakeint0')
        update_vrrp.assert_called_once_with(
            _LB_mock, net_config, _amphora_mock, None)

        # Failure: the amphora is marked as ERROR.
        mock_amphora_repo_update.reset_mock()
        task.execute(_LB_mock.id, net_config, _amphora_mock, 'fakeint0')
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, _amphora_mock.id, status=constants.ERROR)

    @mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
    def test_amphora_index_vrrp_update(self,
                                       mock_lb_get,
                                       mock_driver,
                                       mock_generate_uuid,
                                       mock_log,
                                       mock_get_session,
                                       mock_listener_repo_get,
                                       mock_listener_repo_update,
                                       mock_amphora_repo_update):
        """Exercise AmphoraIndexVRRPUpdate across its outcomes.

        Covers four cases:

        * a reachable amphora: config is pushed with the timeout dict;
        * an amphora flagged UNREACHABLE in ``amphorae_status``: the driver
          is not called;
        * a failing driver call for the amphora whose id matches the
          ``new_amphora_id`` argument: status is set to ERROR;
        * a failing driver call for a non-matching id: no status update.
        """
        amphorae_network_config = mock.MagicMock()
        mock_lb_get.return_value = _LB_mock
        amphorae_status = {
            _amphora_mock.id: {
                constants.UNREACHABLE: False
            }
        }

        amphora_vrrp_update_obj = (
            amphora_driver_tasks.AmphoraIndexVRRPUpdate())

        amphora_vrrp_update_obj.execute(_LB_mock.id, amphorae_network_config,
                                        0, [_amphora_mock], amphorae_status,
                                        'fakeint0',
                                        _amphora_mock.id,
                                        timeout_dict=self.timeout_dict)
        mock_driver.update_vrrp_conf.assert_called_once_with(
            _LB_mock, amphorae_network_config, _amphora_mock,
            self.timeout_dict)

        # Unreachable amp
        amphorae_status = {
            _amphora_mock.id: {
                constants.UNREACHABLE: True
            }
        }
        mock_amphora_repo_update.reset_mock()
        mock_driver.update_vrrp_conf.reset_mock()
        amphora_vrrp_update_obj.execute(LB_ID, amphorae_network_config,
                                        0, [_amphora_mock], amphorae_status,
                                        None, _amphora_mock.id)
        mock_driver.update_vrrp_conf.assert_not_called()

        # From here on every driver call must fail. Use one exception rather
        # than a side_effect list: once a list is exhausted, mock raises
        # StopIteration (itself an Exception subclass) instead. The failure
        # cases below could then pass for the wrong reason.
        mock_driver.update_vrrp_conf.side_effect = Exception('boom')

        # Test with an exception
        mock_amphora_repo_update.reset_mock()
        mock_driver.update_vrrp_conf.reset_mock()
        amphora_vrrp_update_obj.execute(_LB_mock.id, amphorae_network_config,
                                        0, [_amphora_mock], {}, 'fakeint0',
                                        _amphora_mock.id)
        mock_driver.update_vrrp_conf.assert_called_once_with(
            _LB_mock, amphorae_network_config, _amphora_mock, None)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, _amphora_mock.id, status=constants.ERROR)

        # Test with an exception, secondary amp
        mock_amphora_repo_update.reset_mock()
        mock_driver.update_vrrp_conf.reset_mock()
        amphora_vrrp_update_obj.execute(LB_ID, amphorae_network_config,
                                        0, [_amphora_mock], {}, 'fakeint0',
                                        '1234')
        mock_driver.update_vrrp_conf.assert_called_once_with(
            _LB_mock, amphorae_network_config, _amphora_mock, None)
        mock_amphora_repo_update.assert_not_called()

    def test_amphora_vrrp_start(self,
                                mock_driver,
                                mock_generate_uuid,
                                mock_log,
                                mock_get_session,
                                mock_listener_repo_get,
                                mock_listener_repo_update,
                                mock_amphora_repo_update):
        """AmphoraVRRPStart starts VRRP on the amphora with the timeouts."""
        task = amphora_driver_tasks.AmphoraVRRPStart()

        task.execute(_amphora_mock, timeout_dict=self.timeout_dict)

        start_vrrp = mock_driver.start_vrrp_service
        start_vrrp.assert_called_once_with(_amphora_mock, self.timeout_dict)

    def test_amphora_index_vrrp_start(self,
                                      mock_driver,
                                      mock_generate_uuid,
                                      mock_log,
                                      mock_get_session,
                                      mock_listener_repo_get,
                                      mock_listener_repo_update,
                                      mock_amphora_repo_update):
        """Exercise AmphoraIndexVRRPStart across its outcomes.

        Covers four cases:

        * a reachable amphora: VRRP is started with the timeout dict;
        * an amphora flagged UNREACHABLE in ``amphorae_status``: the driver
          is not called;
        * a failing start for the amphora whose id matches the
          ``new_amphora_id`` argument: status is set to ERROR;
        * a failing start for a non-matching id: no status update.
        """
        amphorae_status = {
            _amphora_mock.id: {
                constants.UNREACHABLE: False
            }
        }

        amphora_vrrp_start_obj = (
            amphora_driver_tasks.AmphoraIndexVRRPStart())

        amphora_vrrp_start_obj.execute(0, [_amphora_mock], amphorae_status,
                                       _amphora_mock.id,
                                       timeout_dict=self.timeout_dict)
        mock_driver.start_vrrp_service.assert_called_once_with(
            _amphora_mock, self.timeout_dict)

        # Unreachable amp
        mock_driver.start_vrrp_service.reset_mock()
        amphorae_status = {
            _amphora_mock.id: {
                constants.UNREACHABLE: True
            }
        }
        amphora_vrrp_start_obj.execute(0, [_amphora_mock], amphorae_status,
                                       _amphora_mock.id,
                                       timeout_dict=self.timeout_dict)
        mock_driver.start_vrrp_service.assert_not_called()

        # From here on every driver call must fail. Use one exception rather
        # than a side_effect list: once a list is exhausted, mock raises
        # StopIteration (itself an Exception subclass) instead. The failure
        # cases below could then pass for the wrong reason.
        mock_driver.start_vrrp_service.side_effect = Exception('boom')

        # Test with a start exception
        mock_driver.start_vrrp_service.reset_mock()
        mock_amphora_repo_update.reset_mock()
        amphora_vrrp_start_obj.execute(0, [_amphora_mock], {},
                                       _amphora_mock.id,
                                       timeout_dict=self.timeout_dict)
        mock_driver.start_vrrp_service.assert_called_once_with(
            _amphora_mock, self.timeout_dict)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, _amphora_mock.id, status=constants.ERROR)

        # Test with a start exception, secondary amp
        mock_driver.start_vrrp_service.reset_mock()
        mock_amphora_repo_update.reset_mock()
        amphora_vrrp_start_obj.execute(0, [_amphora_mock], {}, '1234',
                                       timeout_dict=self.timeout_dict)
        mock_driver.start_vrrp_service.assert_called_once_with(
            _amphora_mock, self.timeout_dict)
        mock_amphora_repo_update.assert_not_called()

    def test_amphora_compute_connectivity_wait(self,
                                               mock_driver,
                                               mock_generate_uuid,
                                               mock_log,
                                               mock_get_session,
                                               mock_listener_repo_get,
                                               mock_listener_repo_update,
                                               mock_amphora_repo_update):
        """AmphoraComputeConnectivityWait queries the driver for info.

        A driver TimeOutException must propagate to the caller, and the
        amphora is expected to be marked ERROR in the repository.
        """
        task = amphora_driver_tasks.AmphoraComputeConnectivityWait()
        get_info = mock_driver.get_info

        # Happy path: one info query against the amphora
        task.execute(_amphora_mock)
        get_info.assert_called_once_with(_amphora_mock)

        # Driver timeout: re-raised, amphora flagged as ERROR
        get_info.side_effect = driver_except.TimeOutException()
        self.assertRaises(driver_except.TimeOutException,
                          task.execute, _amphora_mock)
        mock_amphora_repo_update.assert_called_once_with(
            _session_mock, AMP_ID, status=constants.ERROR)

    @mock.patch('octavia.amphorae.backends.agent.agent_jinja_cfg.'
                'AgentJinjaTemplater.build_agent_config')
    def test_amphora_config_update(self,
                                   mock_build_config,
                                   mock_driver,
                                   mock_generate_uuid,
                                   mock_log,
                                   mock_get_session,
                                   mock_listener_repo_get,
                                   mock_listener_repo_update,
                                   mock_amphora_repo_update):
        """Exercise AmphoraConfigUpdate.

        The agent config is built from the amphora id and the topology,
        taken from the flavor or TOPOLOGY_SINGLE when no flavor is given.
        It is then pushed through the driver.

        AmpDriverNotImplementedError from the driver is tolerated. Any
        other driver exception, here TimeOutException, propagates.
        """
        mock_build_config.return_value = FAKE_CONFIG_FILE
        amp_config_update_obj = amphora_driver_tasks.AmphoraConfigUpdate()
        # One outcome per execute() call below, in order.
        mock_driver.update_amphora_agent_config.side_effect = [
            None, None, driver_except.AmpDriverNotImplementedError,
            driver_except.TimeOutException]
        # With Flavor
        flavor = {constants.LOADBALANCER_TOPOLOGY:
                  constants.TOPOLOGY_ACTIVE_STANDBY}
        amp_config_update_obj.execute(_amphora_mock, flavor)
        mock_build_config.assert_called_once_with(
            _amphora_mock.id, constants.TOPOLOGY_ACTIVE_STANDBY)
        mock_driver.update_amphora_agent_config.assert_called_once_with(
            _amphora_mock, FAKE_CONFIG_FILE)
        # With no Flavor
        mock_driver.reset_mock()
        mock_build_config.reset_mock()
        amp_config_update_obj.execute(_amphora_mock, None)
        mock_build_config.assert_called_once_with(
            _amphora_mock.id, constants.TOPOLOGY_SINGLE)
        mock_driver.update_amphora_agent_config.assert_called_once_with(
            _amphora_mock, FAKE_CONFIG_FILE)
        # With amphora that does not support config update: no exception
        mock_driver.reset_mock()
        mock_build_config.reset_mock()
        amp_config_update_obj.execute(_amphora_mock, flavor)
        mock_build_config.assert_called_once_with(
            _amphora_mock.id, constants.TOPOLOGY_ACTIVE_STANDBY)
        mock_driver.update_amphora_agent_config.assert_called_once_with(
            _amphora_mock, FAKE_CONFIG_FILE)
        # With any other driver exception: it propagates to the caller
        # after the config was built and handed to the driver
        mock_driver.reset_mock()
        mock_build_config.reset_mock()
        self.assertRaises(driver_except.TimeOutException,
                          amp_config_update_obj.execute,
                          _amphora_mock, flavor)
        mock_build_config.assert_called_once_with(
            _amphora_mock.id, constants.TOPOLOGY_ACTIVE_STANDBY)
        mock_driver.update_amphora_agent_config.assert_called_once_with(
            _amphora_mock, FAKE_CONFIG_FILE)

    @mock.patch('octavia.db.repositories.AmphoraRepository.get')
    def test_amphorae_get_connectivity_status(self,
                                              mock_amphora_repo_get,
                                              mock_driver,
                                              mock_generate_uuid,
                                              mock_log,
                                              mock_get_session,
                                              mock_listener_repo_get,
                                              mock_listener_repo_update,
                                              mock_amphora_repo_update):
        """AmphoraeGetConnectivityStatus reports per-amphora reachability.

        Each amphora's DB record is fetched and checked through the driver.

        If the check fails for the amphora whose id is passed as the second
        argument, the exception is raised. If it fails for any other
        amphora, that amphora is reported as UNREACHABLE in the returned
        mapping.
        """
        amp1 = mock.MagicMock()
        amp1.id = 'id1'
        amp2 = mock.MagicMock()
        amp2.id = 'id2'
        db_amp1 = mock.Mock()
        db_amp2 = mock.Mock()

        task = amphora_driver_tasks.AmphoraeGetConnectivityStatus()
        check = mock_driver.check
        both_checked = [
            mock.call(db_amp1, timeout_dict=self.timeout_dict),
            mock.call(db_amp2, timeout_dict=self.timeout_dict)]

        def prime_repo():
            # The repository hands back the DB records in amphora order.
            mock_amphora_repo_get.side_effect = [db_amp1, db_amp2]

        # All amphorae reachable
        prime_repo()
        check.return_value = None
        status = task.execute([amp1, amp2], amp1.id,
                              timeout_dict=self.timeout_dict)
        check.assert_has_calls(both_checked)
        self.assertFalse(status[amp1.id][constants.UNREACHABLE])
        self.assertFalse(status[amp2.id][constants.UNREACHABLE])

        # amphora1 (the new amphora) unreachable: the failure propagates
        check.reset_mock()
        prime_repo()
        check.side_effect = [driver_except.TimeOutException, None]
        self.assertRaises(driver_except.TimeOutException,
                          task.execute,
                          [amp1, amp2],
                          amp1.id,
                          timeout_dict=self.timeout_dict)
        check.assert_called_with(db_amp1, timeout_dict=self.timeout_dict)

        # amphora2 unreachable: reported in the result, not raised
        check.reset_mock()
        prime_repo()
        check.side_effect = [None, driver_except.TimeOutException]
        status = task.execute([amp1, amp2], amp1.id,
                              timeout_dict=self.timeout_dict)
        check.assert_has_calls(both_checked)
        self.assertFalse(status[amp1.id][constants.UNREACHABLE])
        self.assertTrue(status[amp2.id][constants.UNREACHABLE])
