/* ****************************************************************************
 * Copyright 2019 Open Systems Development BV                                 *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING    *
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER        *
 * DEALINGS IN THE SOFTWARE.                                                  *
 * ***************************************************************************/
#include "mqttclient.h"

// osdev::components::mqtt
#include "clientpaho.h"
#include "mqttutil.h"
#include "mqttidgenerator.h"
#include "mqtttypeconverter.h"
#include "lockguard.h"
#include "uriparser.h"

// std
#include <exception>
#include <iostream>
#include <numeric>
#include <string>
#include <utility>

using namespace osdev::components::mqtt;
using namespace osdev::components::log;

namespace {
/**
 * @brief Build a per-connection client identifier so that a newly created wrapped client
 *        never takes over the broker session of another client using the same base id.
 *
 * Layout of the result: "<clientId>_<clientNumber>_<generated id>".
 *
 * @param clientId     Base identifier supplied by the user of MqttClient.
 * @param clientNumber Ordinal of the wrapped client derived from @p clientId.
 * @return The composed identifier.
 */
std::string generateUniqueClientId(const std::string& clientId, std::size_t clientNumber)
{
    const std::string separator{ "_" };

    std::string uniqueId{ clientId };
    uniqueId += separator;
    uniqueId += std::to_string(clientNumber);
    uniqueId += separator;
    uniqueId += MqttTypeConverter::toStdString(MqttIdGenerator::generate());
    return uniqueId;
}

} // namespace

/**
 * @brief Construct the client wrapper; no broker connection is made here.
 *
 * @param _clientId                 Base client identifier; wrapped clients derive their ids from it.
 * @param deliveryCompleteCallback  Invoked (on the worker thread) for every completed publish delivery
 *                                  reported by the principal client.
 *
 * The worker thread that runs eventHandler() is started in the constructor body, not in the
 * member-initializer list. This guarantees that:
 *  - every member is fully constructed before the thread can touch any of them, independent of
 *    the declaration order in the header;
 *  - Log::init() has run before the worker emits its first log line.
 */
MqttClient::MqttClient(const std::string& _clientId, const std::function<void(const Token& token)>& deliveryCompleteCallback)
    : m_interfaceMutex()
    , m_internalMutex()
    , m_subscriptionMutex()
    , m_endpoint()
    , m_clientId(_clientId)
    , m_activeTokens()
    , m_activeTokensCV()
    , m_deliveryCompleteCallback(deliveryCompleteCallback)
    , m_serverState(this)
    , m_principalClient()
    , m_additionalClients()
    , m_eventQueue(_clientId)
    , m_workerThread()
    , m_deferredSubscriptions()
{
    Log::init( "mqtt-library" );

    // Start the event handler only now that logging is initialised and all members exist.
    m_workerThread = std::thread( &MqttClient::eventHandler, this );

    // Tag first, message second, consistent with every other log call in this file.
    LogInfo( "[MqttClient::MqttClient]", "MQTT Client started" );
}

/**
 * @brief Tear down the client wrapper.
 *
 * Sequence:
 *  1. disconnect() all wrapped clients (this also drops the additional clients);
 *  2. move the principal client out of the member under m_internalMutex and destroy it;
 *  3. stop the event queue and join the worker thread.
 *
 * NOTE(review): the principal client is destroyed (end of the inner scope) before the event queue
 * is stopped and the worker joined. Events queued by connectionStatusChanged() capture raw
 * IMqttClientImpl pointers; if such an event is still pending it may dereference the destroyed
 * client. Confirm whether m_eventQueue.stop() drains or discards pending events and whether this
 * ordering is safe.
 */
MqttClient::~MqttClient()
{
    {
        // LogDebug( "MqttClient", std::string( m_clientId + " - disconnect" ) );
        this->disconnect();
        // Declared before the lock guard so that, by reverse declaration order, the old principal client
        // is destroyed only after m_internalMutex has been released (assuming OSDEV_COMPONENTS_LOCKGUARD
        // declares a scoped guard object). Callbacks fired during its destruction can then lock it.
        decltype(m_principalClient) principalClient{};

        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        LogDebug( "MqttClient", std::string( m_clientId + " - cleanup principal client" ) );
        m_principalClient.swap(principalClient);
    }

    LogDebug( "MqttClient", std::string( m_clientId + " - dtor stop queue" ) );
    // Stopping the queue makes m_eventQueue.pop() return false, which ends eventHandler().
    m_eventQueue.stop();
    if (m_workerThread.joinable()) {
        m_workerThread.join();
    }
    LogDebug( "MqttClient", std::string( m_clientId + " - dtor ready" ) );
}

std::string MqttClient::clientId() const
{
    return m_clientId;
}

/**
 * @brief Register a callback that is notified when the aggregated server state changes.
 * @param cb The slot to invoke on a state change.
 * @return Handle that can later be passed to unregisterStateChangeCallback().
 */
StateChangeCallbackHandle MqttClient::registerStateChangeCallback(const SlotStateChange& cb)
{
    // Lock order used throughout this class: interface mutex first, then internal mutex.
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
    return this->m_serverState.registerStateChangeCallback(cb);
}

/**
 * @brief Remove a previously registered state change callback.
 * @param handle Handle returned by registerStateChangeCallback().
 */
void MqttClient::unregisterStateChangeCallback(StateChangeCallbackHandle handle)
{
    // Same lock order as the other public entry points: interface, then internal.
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
    this->m_serverState.unregisterStateChangeCallback(handle);
}

void MqttClient::clearAllStateChangeCallbacks()
{
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
    m_serverState.clearAllStateChangeCallbacks();
}

/**
 * @brief The aggregated server state as currently recorded by the server-state tracker.
 * @return The current StateEnum value.
 */
StateEnum MqttClient::state() const
{
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
    const StateEnum current = m_serverState.state();
    return current;
}

void MqttClient::connect(const std::string& host, int port, const Credentials &credentials, const mqtt_LWT &lwt, bool blocking, const LogSettings &log_settings )
{
    osdev::components::mqtt::ParsedUri _endpoint = {
        { "scheme", "tcp" },
        { "user", credentials.username() },
        { "password", credentials.password() },
        { "host", host },
        { "port", std::to_string(port) }
    };

    this->connect( UriParser::toString( _endpoint ), lwt, blocking, log_settings );
}

void MqttClient::connect( const std::string &_endpoint, const mqtt_LWT &lwt, bool blocking, const LogSettings &log_settings )
{
    Log::setLogLevel( log_settings.level );
    Log::setMask( log_settings.mask );

    LogInfo( "MqttClient", std::string( m_clientId + " - Request connect" ) );

    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    IMqttClientImpl* client(nullptr);
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (m_principalClient && m_principalClient->connectionStatus() != ConnectionStatus::Disconnected) {
            if (_endpoint == m_endpoint)
            {
                // idempotent
                return;
            }
            else
            {
                LogError( "MqttClient", std::string( m_clientId + " - Cannot connect to different endpoint. Disconnect first." ) );
                return;
            }
        }
        m_endpoint = _endpoint;
        if (!m_principalClient)
        {
            std::string derivedClientId(generateUniqueClientId(m_clientId, 1));
            m_principalClient = std::make_unique<ClientPaho>(
                m_endpoint,
                derivedClientId,
                [this](const std::string& id, ConnectionStatus cs) { this->connectionStatusChanged(id, cs); },
                [this](std::string clientId, std::int32_t token) { this->deliveryComplete(clientId, token); });
        }
        client = m_principalClient.get();
    }

    client->connect( blocking, lwt );
}

void MqttClient::disconnect()
{
    LogInfo( "MqttClient", std::string( m_clientId + " - Request disconnect" ) );
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);

    decltype(m_additionalClients) additionalClients{};
    std::vector<IMqttClientImpl*> clients{};
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_principalClient || m_principalClient->connectionStatus() == ConnectionStatus::Disconnected || m_principalClient->connectionStatus() == ConnectionStatus::DisconnectInProgress)
        {
            LogDebug( "MqttClient", std::string( m_clientId + " - Principal client not connected" ) );
            return;
        }
        m_additionalClients.swap( additionalClients );

        for (const auto& c : additionalClients)
        {
            clients.push_back( c.get() );
        }
        clients.push_back( m_principalClient.get() );
    }


    LogDebug( "MqttClient", std::string( m_clientId +  " - Unsubscribe and disconnect clients" ) );
    for ( auto& cl : clients )
    {
        cl->unsubscribeAll();
    }
    this->waitForCompletionInternal(clients, std::chrono::milliseconds(2000), std::set<Token>{});

    for (auto& cl : clients) {
        cl->disconnect(false, 2000);
    }
    this->waitForCompletionInternal(clients, std::chrono::milliseconds(3000), std::set<Token>{});
}

/**
 * @brief Publish a message through the principal client.
 *
 * @param message Message to publish; its topic must be valid and must not contain wildcards.
 * @param qos     Quality of service level passed to the wrapped client.
 * @return Token identifying the publish command, or Token(m_clientId, -1) when the topic is
 *         rejected, the principal client does not exist, or it is Disconnected.
 *
 * Fix: the original evaluated m_principalClient->connectionStatus() even when m_principalClient
 * was null (a null-pointer dereference). The two failure cases are now handled separately.
 */
Token MqttClient::publish(const MqttMessage& message, int qos)
{
    if (hasWildcard(message.topic()))
    {
        LogDebug("[MqttClient::publish]","Topic has wildcard : " + message.topic());
        return Token(m_clientId, -1);
    }
    else if(!isValidTopic(message.topic()))
    {
        LogDebug("[MqttClient::publish]","Topic is invalid : " + message.topic());
        return Token(m_clientId, -1);
    }

    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    IMqttClientImpl* client(nullptr);
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_principalClient)
        {
            // No connect() has been issued yet: nothing to dereference.
            LogInfo( "[MqttClient::publish]", std::string( "Principal client not initialized") );
            LogError("MqttClient", std::string( m_clientId + " - Unable to publish, not connected" ) );
            return Token(m_clientId, -1);
        }

        if (m_principalClient->connectionStatus() == ConnectionStatus::Disconnected)
        {
            std::cout << "Unable to publish, not connected.." << std::endl;
            LogError("MqttClient", std::string( m_clientId + " - Unable to publish, not connected" ) );
            return Token(m_clientId, -1);
        }
        client = m_principalClient.get();
    }

    // Defensive: cannot be null after the checks above, kept as a guard against future changes.
    if(!client)
    {
        LogDebug("[MqttClient::publish]", "Invalid pointer to IMqttClient retrieved.");
        return Token(m_clientId, -1);
    }

    return Token(client->clientId(), client->publish(message, qos));
}

/**
 * @brief Subscribe to @p topic, delivering received messages to @p cb.
 *
 * - If there is no principal client, or it is not Connected, the request (topic, qos, cb) is stored
 *   in m_deferredSubscriptions and Token(m_clientId, -1) is returned. connectionStatusChanged()
 *   replays the deferred requests once the aggregated state becomes Good.
 * - Otherwise the request goes to the principal client if isOverlapping(topic) is false for it,
 *   else to the first additional client for which it is false. isOverlapping presumably checks
 *   against topics that client already subscribed to; confirm in ClientPaho.
 * - If no such client exists, a new ClientPaho is created with client number
 *   (number of additional clients + 2), since the principal is number 1. It is connected in blocking
 *   mode and gets no delivery-complete callback.
 *
 * @return Token{ wrapped client id, result of the wrapped subscribe() }, or an id of -1 when deferred.
 *
 * NOTE(review): both lock guards below are commented out. m_principalClient and
 * m_additionalClients are therefore read here, and m_additionalClients is modified (emplace_back),
 * without m_internalMutex, while other members access them under that lock. This is a potential
 * data race. Simply re-enabling the internal lock would nest m_subscriptionMutex inside it. That is
 * the reverse of connectionStatusChanged(), which holds m_subscriptionMutex while calling this
 * function, so the fix needs a lock-order decision.
 * NOTE(review): when called from connectionStatusChanged() and the principal is not Connected, the
 * deferred path locks m_subscriptionMutex again on the same thread. This deadlocks unless that
 * mutex is recursive.
 */
Token MqttClient::subscribe(const std::string& topic, int qos, const std::function<void(MqttMessage)>& cb)
{
    LogDebug( "[MqttClient::subscribe]", std::string( m_clientId + " - Subscribe to topic " + topic ) );
    // OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    bool clientFound = false;
    IMqttClientImpl* client(nullptr);
    {
        // OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_principalClient || m_principalClient->connectionStatus() != ConnectionStatus::Connected)
        {
            LogError("MqttClient", std::string( m_clientId + " - Unable to subscribe, not connected" ) );
            // Store the subscription in the buffer for later processing.
            {
                OSDEV_COMPONENTS_LOCKGUARD(m_subscriptionMutex);
                m_deferredSubscriptions.emplace_back( topic, qos, cb );
            }

            return Token(m_clientId, -1);
        }

        // Prefer the principal client; fall back to the first additional client without an overlapping subscription.
        if (!m_principalClient->isOverlapping(topic))
        {
            client = m_principalClient.get();
            clientFound = true;
        }
        else
        {
            for (const auto& c : m_additionalClients)
            {
                if (!c->isOverlapping(topic))
                {
                    client = c.get();
                    clientFound = true;
                    break;
                }
            }
        }

        if (!clientFound)
        {
            LogDebug("[MqttClient::subscribe]", std::string( m_clientId + " - Creating new ClientPaho instance for subscription on topic " + topic ) );
            std::string derivedClientId(generateUniqueClientId(m_clientId, m_additionalClients.size() + 2)); // principal client is nr 1.
            m_additionalClients.emplace_back(std::make_unique<ClientPaho>(
                m_endpoint,
                derivedClientId,
                [this](const std::string& id, ConnectionStatus cs) { this->connectionStatusChanged(id, cs); },
                std::function<void(const std::string&, std::int32_t)>{}));
            client = m_additionalClients.back().get();
        }
    }

    // A freshly created additional client is connected (blocking) before it is asked to subscribe.
    if (!clientFound)
    {
        client->connect( true );
    }
    return Token{ client->clientId(), client->subscribe(topic, qos, cb) };
}

std::set<Token> MqttClient::unsubscribe(const std::string& topic, int qos)
{
    LogDebug("[MqttClient::unsubscribe]", std::string( m_clientId + " - Unsubscribe from topic " + topic ) );
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    std::vector<IMqttClientImpl*> clients{};
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_principalClient || m_principalClient->connectionStatus() == ConnectionStatus::Disconnected)
        {
            LogError("[MqttClient::unsubscribe]", std::string( m_clientId + " - Unable to unsubscribe, not connected" ) );
            // Throw (MqttException, "Not connected");
            return std::set<Token>();
        }

        clients.push_back(m_principalClient.get());
        for (const auto& c : m_additionalClients)
        {
            clients.push_back(c.get());
        }
    }
    std::set<Token> tokens{};
    for (auto* c : clients) {
        auto token = c->unsubscribe(topic, qos);
        if (-1 != token) {
            tokens.emplace(Token{ c->clientId(), token });
        }
    }
    return tokens;
}

/**
 * @brief Wait up to @p waitFor for all outstanding commands to complete.
 * @return true when everything completed within the time budget.
 */
bool MqttClient::waitForCompletion(std::chrono::milliseconds waitFor) const
{
    // An empty token set means "wait for every outstanding operation".
    const std::set<Token> everything{};
    return this->waitForCompletion(waitFor, everything);
}

/**
 * @brief Wait up to @p waitFor for the command identified by @p token to complete.
 * @return false immediately for a token whose id is -1 (a command that was never issued);
 *         otherwise whether the command completed within the time budget.
 */
bool MqttClient::waitForCompletion(std::chrono::milliseconds waitFor, const Token& token) const
{
    const bool neverIssued = (token.token() == -1);
    if (neverIssued) {
        return false;
    }

    const std::set<Token> single{ token };
    return this->waitForCompletion(waitFor, single);
}

bool MqttClient::waitForCompletion(std::chrono::milliseconds waitFor, const std::set<Token>& tokens) const
{
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    std::vector<IMqttClientImpl*> clients{};
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (m_principalClient) {
            clients.push_back(m_principalClient.get());
        }
        for (const auto& c : m_additionalClients) {
            clients.push_back(c.get());
        }
    }
    return waitForCompletionInternal(clients, waitFor, tokens);
}

/**
 * @brief Query the outcome of the command identified by @p token.
 * @return The wrapped client's operationResult() for the token, or an empty optional when no
 *         wrapped client has the token's client id.
 */
boost::optional<bool> MqttClient::commandResult(const Token& token) const
{
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);

    std::vector<IMqttClientImpl*> candidates{};
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (auto* principal = m_principalClient.get()) {
            candidates.push_back(principal);
        }
        for (const auto& extra : m_additionalClients) {
            candidates.push_back(extra.get());
        }
    }

    // The token's client id identifies the wrapped client that issued the command.
    for (auto* candidate : candidates) {
        if (candidate->clientId() == token.clientId()) {
            return candidate->operationResult(token.token());
        }
    }
    return boost::optional<bool>{};
}

std::string MqttClient::endpoint() const
{
    auto ep = UriParser::parse(m_endpoint);
    if (ep.find("user") != ep.end()) {
        ep["user"].clear();
    }
    if (ep.find("password") != ep.end()) {
        ep["password"].clear();
    }
    return UriParser::toString(ep);
}

/**
 * @brief Callback from a wrapped client whose connection status changed.
 *
 * @param id  Client id of the wrapped client that reported the change.
 * @param cs  Its new connection status (only used for logging here).
 *
 * Steps:
 *  1. Under m_internalMutex, snapshot all wrapped clients (principal first) and their current states.
 *  2. Aggregate the states with determineState().
 *  3. If the aggregate is Good:
 *     - replay every deferred subscription through subscribe() while holding m_subscriptionMutex;
 *     - clear m_activeTokens;
 *     - call resubscribe() on every client (exceptions are logged, not propagated);
 *     - notify waiters on m_activeTokensCV.
 *  4. Queue an event on the worker thread. If a resubscribe was done, it waits up to 1000 ms for the
 *     clients' commands, logs a warning if subscriptions are still pending, and calls
 *     publishPending() on the principal client. It always ends by emitting the new state.
 *
 * NOTE(review): resubscription triggers on every status change whose aggregate is Good, not only on
 * the ConnectionFailure -> Good transition (see the commented-out variant below).
 * NOTE(review): the queued event captures raw IMqttClientImpl pointers. Those clients must outlive
 * the event; disconnect() and the destructor destroy clients before the queue is necessarily drained.
 * NOTE(review): subscribe() is called with m_subscriptionMutex held. If subscribe() takes its
 * "not connected" path it locks m_subscriptionMutex again, which deadlocks unless that mutex is
 * recursive. If it were recursive, the re-deferred entry would be appended and the while-loop could
 * keep cycling.
 * NOTE(review): clearing m_activeTokens discards tokens of deliveryComplete events that are still
 * queued; their later erase() is then a no-op.
 */
void MqttClient::connectionStatusChanged(const std::string& id, ConnectionStatus cs)
{
    LogDebug("[MqttClient::connectionStatusChanged]", std::string( m_clientId + " - connection status of wrapped client " + id + " changed to " + std::to_string( static_cast<int>(cs) ) ) );
    IMqttClientImpl* principalClient{ nullptr };
    std::vector<IMqttClientImpl*> clients{};
    std::vector<ConnectionStatus> connectionStates{};
    {
        // Snapshot clients and their states consistently.
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);

        if (m_principalClient)
        {
            principalClient = m_principalClient.get();
            clients.push_back(principalClient);
            connectionStates.push_back(m_principalClient->connectionStatus());
        }

        for (const auto& c : m_additionalClients)
        {
            clients.push_back(c.get());
            connectionStates.push_back(c->connectionStatus());
        }
    }

    auto newState = determineState(connectionStates);
    // bool resubscribe = (StateEnum::ConnectionFailure == m_serverState.state() && StateEnum::Good == newState);
    bool resubscribe = ( StateEnum::Good == newState );
    if (resubscribe)
    {
        // First activate pending subscriptions
        {
            OSDEV_COMPONENTS_LOCKGUARD(m_subscriptionMutex);
            LogDebug( "[MqttClient::connectionsStatusChanged]", std::string( m_clientId + " - Number of pending subscriptions : " + std::to_string(m_deferredSubscriptions.size() ) ) );
            // Process the oldest deferred request first; it is removed only after subscribe() returns.
            while( m_deferredSubscriptions.size() > 0 )
            {
                auto subscription = m_deferredSubscriptions.at( 0 );
                this->subscribe( subscription.getTopic(), subscription.getQoS(), subscription.getCallBack() );
                m_deferredSubscriptions.erase( m_deferredSubscriptions.begin() );
            }
        }

        LogDebug( "[MqttClient::connectionStatusChanged]",
                  std::string( m_clientId + " - Resubscribing..." ) );
        {
            OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
            m_activeTokens.clear();
        }

        for (auto* cl : clients)
        {
            try
            {
                LogDebug( "[MqttClient::connectionStatusChanged]", std::string( m_clientId + " - Client " + cl->clientId() + " has " + std::string( cl->hasPendingSubscriptions() ? "" : "no" ) + " pending subscriptions" ) );
                cl->resubscribe();
            }
            catch (const std::exception& e)
            {
                LogError("[MqttClient::connectionStatusChanged]", std::string( m_clientId + " - resubscribe on wrapped client " + cl->clientId() + " in context of connection status change in wrapped client : " + id + " => FAILED : " + e.what() ) );
            }
        }
        // m_activeTokens was cleared above; wake anyone waiting for it to become empty.
        m_activeTokensCV.notify_all();
    }

    // The server state change and a possible resubscription are done in the context of the MqttClient worker thread
    // The wrapper is free to pick up new work such as the acknowledgment of the just recreated subscriptions.
    this->pushEvent([this, resubscribe, clients, principalClient, newState]() {
        if (resubscribe)
        {
            // Just wait for the subscription commands to complete. We do not use waitForCompletionInternal because that call will always timeout when there are active tokens.
            // Active tokens are removed typically by work done on the worker thread. The wait action is also performed on the worker thread.
            auto waitFor = std::chrono::milliseconds(1000);
            if (!waitForCompletionInternalClients(clients, waitFor, std::set<Token>{}))
            {
                if (std::accumulate(clients.begin(),
                                    clients.end(),
                                    false,
                                    [](bool hasPending, IMqttClientImpl* client)
                                    {
                                        return hasPending || client->hasPendingSubscriptions();
                                    }))
                {
                    LogWarning("[MqttClient::connectionStatusChanged]", std::string( m_clientId + " - subscriptions are not recovered within timeout." ) );
                }
            }
            if (principalClient)
            {
                try
                {
                    principalClient->publishPending();
                }
                catch (const std::exception& e)
                {
                    LogError( "[MqttClient::connectionStatusChanged]", std::string( m_clientId + " - publishPending on wrapped client " + principalClient->clientId() + " => FAILED " + e.what() ) );
                }
            }
        }
        m_serverState.emitStateChanged(newState);
    });
}

/**
 * @brief Delivery-complete callback from the principal client. Additional clients are created
 *        with an empty delivery callback in subscribe().
 *
 * Does nothing when no user delivery callback was configured. Otherwise the token is recorded in
 * m_activeTokens, and an event is queued that invokes the user callback on the worker thread.
 * After the callback the event removes the token again and notifies m_activeTokensCV, so
 * waitForCompletionInternal() can observe completion.
 *
 * @param _clientId  Client id of the wrapped client that completed the delivery.
 * @param token      Id of the completed publish command.
 */
void MqttClient::deliveryComplete(const std::string& _clientId, std::int32_t token)
{
    if (!m_deliveryCompleteCallback) {
        return;
    }

    Token t(_clientId, token);
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_activeTokens.insert(t).second) {
            // This should not happen. This means that some callback on the wrapper never came.
            LogDebug("[MqttClient::deliveryComplete]", std::string( m_clientId + " - deliveryComplete, token is already active" ) );
        }
    }
    this->pushEvent([this, t]() {
        // Presumably runs the lambda when this event's scope exits (also on an exception from the
        // user callback); confirm OSDEV_COMPONENTS_SCOPEGUARD semantics.
        OSDEV_COMPONENTS_SCOPEGUARD(m_activeTokens, [this, &t]() {
            {
                OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
                m_activeTokens.erase(t);
            }
            m_activeTokensCV.notify_all();
        });
        m_deliveryCompleteCallback(t);
    });
}

/**
 * @brief Two-phase wait for command completion.
 *
 * Phase 1 lets every wrapped client in @p clients wait for its own commands, consuming part of
 * @p waitFor. Phase 2 waits, with whatever time is left, on m_activeTokensCV until none of
 * @p tokens is still in m_activeTokens. An empty @p tokens means "no active tokens at all".
 *
 * @return false if phase 1 used up the whole budget, otherwise the result of the timed wait.
 */
bool MqttClient::waitForCompletionInternal(const std::vector<IMqttClientImpl*>& clients, std::chrono::milliseconds waitFor, const std::set<Token>& tokens) const
{
    // Phase 1: per-client wait; waitFor is reduced by the time spent.
    if (!waitForCompletionInternalClients(clients, waitFor, tokens)) {
        return false;
    }

    // Phase 2 predicate: evaluated with m_internalMutex held by the condition variable wait.
    const auto nothingPending = [this, &tokens]() -> bool {
        if (tokens.empty()) {
            return m_activeTokens.empty();
        }
        if (tokens.size() == 1) {
            return m_activeTokens.find(*tokens.cbegin()) == m_activeTokens.end();
        }
        std::vector<Token> stillActive{};
        std::set_intersection(m_activeTokens.begin(), m_activeTokens.end(),
                              tokens.begin(), tokens.end(),
                              std::back_inserter(stillActive));
        return stillActive.empty();
    };

    std::unique_lock<std::mutex> guard(m_internalMutex);
    return m_activeTokensCV.wait_for(guard, waitFor, nothingPending);
}

/**
 * @brief Let each wrapped client wait for its own outstanding commands.
 *
 * For each client, the command ids in @p tokens whose client id matches are collected. The client
 * waits for those ids, or for everything when @p tokens is empty; clients without matching tokens
 * are skipped otherwise. The time each wait reports is subtracted from @p waitFor (in/out).
 *
 * @return true if budget remains. Otherwise @p waitFor is clamped to zero and false is returned.
 */
bool MqttClient::waitForCompletionInternalClients(const std::vector<IMqttClientImpl*>& clients, std::chrono::milliseconds& waitFor, const std::set<Token>& tokens) const
{
    const bool waitForEverything = tokens.empty();

    for (auto* client : clients) {
        std::set<std::int32_t> ownTokens{};
        for (const auto& candidate : tokens) {
            if (candidate.clientId() == client->clientId()) {
                ownTokens.insert(candidate.token());
            }
        }

        if (!waitForEverything && ownTokens.empty()) {
            continue; // nothing of interest on this client
        }
        waitFor -= client->waitForCompletion(waitFor, ownTokens);
    }

    if (waitFor <= std::chrono::milliseconds(0)) {
        waitFor = std::chrono::milliseconds(0);
        return false;
    }
    return true;
}

/**
 * @brief Aggregate the connection states of all wrapped clients into one server state.
 *
 * - Any client reconnecting                                    -> StateEnum::ConnectionFailure
 * - Otherwise any client disconnected, disconnecting or connecting -> StateEnum::Unknown
 * - Otherwise, when every client is connected                  -> StateEnum::Good
 *
 * Fix: an empty list (no wrapped clients at all, e.g. after the destructor has swapped out the
 * principal client) previously yielded Good, because 0 == 0 connected clients. That made
 * connectionStatusChanged() start resubscribing and replaying deferred subscriptions with no
 * client to serve them. It now yields Unknown.
 *
 * @param connectionStates One entry per wrapped client.
 * @return The aggregated state.
 */
StateEnum MqttClient::determineState(const std::vector<ConnectionStatus>& connectionStates)
{
    if (connectionStates.empty()) {
        return StateEnum::Unknown;
    }

    std::size_t unknownStates = 0;
    std::size_t goodStates = 0;
    std::size_t reconnectStates = 0;
    for (auto cst : connectionStates) {
        switch (cst) {
            case ConnectionStatus::Disconnected:
            case ConnectionStatus::DisconnectInProgress: // count as unknown because we don't want resubscribes to trigger when a wrapper is in this state.
            case ConnectionStatus::ConnectInProgress:    // count as unknown because the wrapper is not connected yet.
                ++unknownStates;
                break;
            case ConnectionStatus::ReconnectInProgress:
                ++reconnectStates;
                break;
            case ConnectionStatus::Connected:
                ++goodStates;
                break;
        }
    }

    auto newState = StateEnum::Unknown;
    if (reconnectStates > 0) {
        newState = StateEnum::ConnectionFailure;
    }
    else if (unknownStates > 0) {
        newState = StateEnum::Unknown;
    }
    else if (connectionStates.size() == goodStates) {
        newState = StateEnum::Good;
    }
    return newState;
}

/**
 * @brief Queue @p ev for execution on the worker thread (see eventHandler()).
 * @param ev Work item. It is taken by value and moved into the queue, which avoids a second copy
 *           of the callable and its captures (e.g. the client vector captured in
 *           connectionStatusChanged()).
 */
void MqttClient::pushEvent(std::function<void()> ev)
{
    m_eventQueue.push(std::move(ev));
}

void MqttClient::eventHandler()
{
    LogInfo("[MqttClient::eventHandler]", std::string( m_clientId + " - starting event handler." ) );
    for (;;)
    {
        std::vector<std::function<void()>> events;
        if (!m_eventQueue.pop(events))
        {
            break;
        }
        for (const auto& ev : events)
        {
            ev();
        }
    }
    LogInfo("[MqttClient::eventHandler]", std::string( m_clientId + " - leaving event handler." ) );
}

/**
 * @brief Forward a log mask to the logging facility (Log::setMask).
 * @param logMask The mask to apply.
 */
void MqttClient::setMask(log::LogMask logMask )
{
    Log::setMask(logMask);
}

/**
 * @brief Forward a log level to the logging facility (Log::setLogLevel).
 * @param logLevel The level to apply.
 */
void MqttClient::setLogLevel(log::LogLevel logLevel)
{
    Log::setLogLevel(logLevel);
}

/**
 * @brief Forward a logging context string to the logging facility (Log::setContext).
 * @param context The context to apply.
 */
void MqttClient::setContext(std::string context)
{
    const std::string& logContext = context;
    Log::setContext(logContext);
}