/* ****************************************************************************
 * Copyright 2019 Open Systems Development BV                                 *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING    *
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER        *
 * DEALINGS IN THE SOFTWARE.                                                  *
 * ***************************************************************************/
#include "mqttclient.h"

// osdev::components::mqtt
#include "clientpaho.h"
#include "mqttutil.h"
#include "mqttidgenerator.h"
#include "mqtttypeconverter.h"
#include "log.h"
#include "lockguard.h"
#include "uriparser.h"

// std
#include <numeric>
#include <iostream>

using namespace osdev::components::mqtt;

namespace {
/**
 * @brief Derive a client id that is unique per wrapped client, so that a new client does not
 *        take over the connection of an already existing client.
 * @param clientId The base client identifier.
 * @param clientNumber Sequence number of the client derived from the base identifier.
 * @return "<clientId>_<clientNumber>_<generated id>".
 */
std::string generateUniqueClientId(const std::string& clientId, std::size_t clientNumber)
{
    std::string uniqueId(clientId);
    uniqueId += "_";
    uniqueId += std::to_string(clientNumber);
    uniqueId += "_";
    uniqueId += MqttTypeConverter::toStdString(MqttIdGenerator::generate());
    return uniqueId;
}

} // namespace

/**
 * @brief Construct the client wrapper and start its event worker thread.
 * @param _clientId Base client identifier. It is stored and also handed to the event queue.
 * @param deliveryCompleteCallback Callback stored for later invocation when a delivery completes.
 */
MqttClient::MqttClient(const std::string& _clientId, const std::function<void(const Token& token)>& deliveryCompleteCallback)
    : m_interfaceMutex()
    , m_internalMutex()
    , m_endpoint()
    , m_clientId(_clientId)
    , m_activeTokens()
    , m_activeTokensCV()
    , m_deliveryCompleteCallback(deliveryCompleteCallback)
    , m_serverState(this)
    , m_principalClient()
    , m_additionalClients()
    , m_eventQueue(_clientId)
    // The worker thread runs eventHandler() and is started here, i.e. before the constructor body executes.
    // NOTE(review): this relies on m_workerThread being declared after every member that eventHandler() uses
    // (at least m_eventQueue) - confirm against the class declaration.
    , m_workerThread( std::thread( &MqttClient::eventHandler, this ) )
{
    // Initialize the logger.
    Log::init( "mqtt-cpp", "", LogLevel::Info );
}

/**
 * @brief Disconnect, release the principal client, stop the event queue and join the worker thread.
 */
MqttClient::~MqttClient()
{
    LogDebug( "MqttClient", std::string( m_clientId + " - disconnect" ) );
    {
        // NOTE(review): this message is identical to the one logged just above; one of them looks redundant.
        LogDebug( "MqttClient", std::string( m_clientId + " - disconnect" ) );
        this->disconnect();
        // Local holder that takes over the principal client. It is declared before the lock guard below,
        // so at the end of this scope the guard goes out of scope first and the client is destroyed afterwards.
        decltype(m_principalClient) principalClient{};

        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        LogDebug( "MqttClient", std::string( m_clientId + " - cleanup principal client" ) );
        m_principalClient.swap(principalClient);
    }

    // Stop the queue and wait for the worker thread (started in the constructor) to finish.
    LogDebug( "MqttClient", std::string( m_clientId + " - dtor stop queue" ) );
    m_eventQueue.stop();
    if (m_workerThread.joinable()) {
        m_workerThread.join();
    }
    LogDebug( "MqttClient", std::string( m_clientId + " - dtor ready" ) );
}

std::string MqttClient::clientId() const
{
    return m_clientId;
}

/**
 * @brief Register a callback on the server state object.
 * @param cb The slot to register.
 * @return The handle returned by the server state for this registration.
 */
StateChangeCallbackHandle MqttClient::registerStateChangeCallback(const SlotStateChange& cb)
{
    // Both mutexes are taken, interface mutex first.
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
    return this->m_serverState.registerStateChangeCallback(cb);
}

/**
 * @brief Remove a previously registered state change callback from the server state object.
 * @param handle The handle that identifies the registration.
 */
void MqttClient::unregisterStateChangeCallback(StateChangeCallbackHandle handle)
{
    // Both mutexes are taken, interface mutex first.
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
    this->m_serverState.unregisterStateChangeCallback(handle);
}

void MqttClient::clearAllStateChangeCallbacks()
{
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
    m_serverState.clearAllStateChangeCallbacks();
}

/**
 * @brief Get the state currently held by the server state object.
 * @return The current state.
 */
StateEnum MqttClient::state() const
{
    // Both mutexes are taken, interface mutex first.
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
    return this->m_serverState.state();
}

void MqttClient::connect(const std::string& host, int port, const Credentials& credentials)
{

    osdev::components::mqtt::ParsedUri _endpoint = {
        { "scheme", "tcp" },
        { "user", credentials.username() },
        { "password", credentials.password() },
        { "host", host },
        { "port", std::to_string(port) }
    };
    this->connect(UriParser::toString(_endpoint));
}

void MqttClient::connect(const std::string& _endpoint)
{
    LogInfo( "MqttClient", std::string( m_clientId + " - Request connect" ) );

    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    IMqttClientImpl* client(nullptr);
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (m_principalClient && m_principalClient->connectionStatus() != ConnectionStatus::Disconnected) {
            if (_endpoint == m_endpoint)
            {
                // idempotent
                return;
            }
            else
            {
                LogError( "MqttClient", std::string( m_clientId + " - Cannot connect to different endpoint. Disconnect first." ) );
                return;
            }
        }
        m_endpoint = _endpoint;
        if (!m_principalClient) {
            std::string derivedClientId(generateUniqueClientId(m_clientId, 1));
            m_principalClient = std::make_unique<ClientPaho>(
                m_endpoint,
                derivedClientId,
                [this](const std::string& id, ConnectionStatus cs) { this->connectionStatusChanged(id, cs); },
                [this](std::string clientId, std::int32_t token) { this->deliveryComplete(clientId, token); });
        }
        client = m_principalClient.get();
    }
    client->connect(true);
}

void MqttClient::disconnect()
{
    LogInfo( "MqttClient", std::string( m_clientId + " - Request disconnect" ) );
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);

    decltype(m_additionalClients) additionalClients{};
    std::vector<IMqttClientImpl*> clients{};
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_principalClient || m_principalClient->connectionStatus() == ConnectionStatus::Disconnected || m_principalClient->connectionStatus() == ConnectionStatus::DisconnectInProgress)
        {
            return;
        }
        m_additionalClients.swap( additionalClients );

        for (const auto& c : additionalClients)
        {
            clients.push_back( c.get() );
        }
        clients.push_back( m_principalClient.get() );
    }


    LogDebug( "MqttClient", std::string( m_clientId +  " - Unsubscribe and disconnect clients" ) );
    // DebugLogToFIle ("MqttClient", "%1 - Unsubscribe and disconnect clients", m_clientId);
    for ( auto& cl : clients )
    {
        cl->unsubscribeAll();
    }
    this->waitForCompletionInternal(clients, std::chrono::milliseconds(2000), std::set<Token>{});

    for (auto& cl : clients) {
        cl->disconnect(false, 2000);
    }
    this->waitForCompletionInternal(clients, std::chrono::milliseconds(3000), std::set<Token>{});
}

/**
 * @brief Publish a message through the principal client.
 *
 * @param message The message to publish. Its topic must be valid and must not contain wildcards.
 * @param qos The quality of service passed on to the wrapped client.
 * @return The token of the publish command. A token with value -1 (and the base client id) is returned
 *         when the topic is rejected or when no principal client exists yet.
 */
Token MqttClient::publish(const MqttMessage& message, int qos)
{
    if (hasWildcard(message.topic()) || !isValidTopic(message.topic())) {
        return Token(m_clientId, -1);
    }
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    IMqttClientImpl* client(nullptr);
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_principalClient) {
            // No principal client has been created yet. Bail out with an invalid token
            // instead of dereferencing a null pointer below.
            std::cout << "Principal client not initialized" << std::endl;
            return Token(m_clientId, -1);
        }
        if (m_principalClient->connectionStatus() == ConnectionStatus::Disconnected) {
            // Only reported; the publish is still handed to the wrapped client as before.
            std::cout << "Unable to publish, not connected.." << std::endl;
        }
        client = m_principalClient.get();
    }
    return Token(client->clientId(), client->publish(message, qos));
}

/**
 * @brief Subscribe to a topic.
 *
 * The subscription is placed on the first wrapped client (principal client first, then the additional
 * clients) for which isOverlapping(topic) is false. When every client overlaps, a new additional client is
 * created, connected and used for the subscription.
 *
 * @param topic The topic (filter) to subscribe to.
 * @param qos The quality of service passed on to the wrapped client.
 * @param cb The callback passed on to the wrapped client for messages on this subscription.
 * @return The token of the subscribe command. A token with value -1 (and the base client id) is returned
 *         when no principal client exists yet.
 */
Token MqttClient::subscribe(const std::string& topic, int qos, const std::function<void(MqttMessage)>& cb)
{
    // NOTE(review): both lock guards are disabled in this method, so it is not synchronized with the other
    // interface methods. They are left disabled here; confirm whether that is intentional.
    // OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    bool clientFound = false;
    IMqttClientImpl* client(nullptr);
    {
        // OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_principalClient) {
            // Without a principal client there is nothing to subscribe on. Return an invalid token
            // instead of dereferencing a null pointer below.
            LogError( "MqttClient", std::string( m_clientId + " - Unable to subscribe, not connected" ) );
            return Token(m_clientId, -1);
        }
        // A principal client in the Disconnected state is not treated as an error; the subscribe is
        // still handed to a wrapped client.
        if (!m_principalClient->isOverlapping(topic)) {
            client = m_principalClient.get();
            clientFound = true;
        }
        else {
            for (const auto& c : m_additionalClients) {
                if (!c->isOverlapping(topic)) {
                    client = c.get();
                    clientFound = true;
                    break;
                }
            }
        }
        if (!clientFound) {
            // Every existing client overlaps with the topic: create one more wrapped client.
            std::string derivedClientId(generateUniqueClientId(m_clientId, m_additionalClients.size() + 2)); // principal client is nr 1.
            m_additionalClients.emplace_back(std::make_unique<ClientPaho>(
                m_endpoint,
                derivedClientId,
                [this](const std::string& id, ConnectionStatus cs) { this->connectionStatusChanged(id, cs); },
                std::function<void(const std::string&, std::int32_t)>{}));
            client = m_additionalClients.back().get();
        }
    }
    if (!clientFound) {
        // A freshly created client has to be connected before it can subscribe.
        client->connect(true);
    }
    return Token{ client->clientId(), client->subscribe(topic, qos, cb) };
}

std::set<Token> MqttClient::unsubscribe(const std::string& topic, int qos)
{
    // DebugLogToFIle ("MqttClient", "%1 - Unsubscribe from topic %2 with qos %3", m_clientId, topic, qos);
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    std::vector<IMqttClientImpl*> clients{};
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_principalClient || m_principalClient->connectionStatus() == ConnectionStatus::Disconnected) {
            // ErrorLogToFIle ("MqttClient", "%1 - Unable to unsubscribe, not connected", m_clientId);
            // Throw (MqttException, "Not connected");
        }
        clients.push_back(m_principalClient.get());
        for (const auto& c : m_additionalClients) {
            clients.push_back(c.get());
        }
    }
    std::set<Token> tokens{};
    for (auto* c : clients) {
        auto token = c->unsubscribe(topic, qos);
        if (-1 != token) {
            tokens.emplace(Token{ c->clientId(), token });
        }
    }
    return tokens;
}

/**
 * @brief Wait for all commands to complete.
 * @param waitFor The maximum time to wait.
 * @return The result of the token-set overload called with an empty token set.
 */
bool MqttClient::waitForCompletion(std::chrono::milliseconds waitFor) const
{
    // An empty set means "all commands".
    const std::set<Token> allCommands{};
    return this->waitForCompletion(waitFor, allCommands);
}

/**
 * @brief Wait for a single command to complete.
 * @param waitFor The maximum time to wait.
 * @param token The token of the command to wait for.
 * @return false when the token value is -1; otherwise the result of the token-set overload.
 */
bool MqttClient::waitForCompletion(std::chrono::milliseconds waitFor, const Token& token) const
{
    const bool invalidToken = (token.token() == -1);
    if (invalidToken) {
        return false;
    }
    const std::set<Token> single{ token };
    return this->waitForCompletion(waitFor, single);
}

bool MqttClient::waitForCompletion(std::chrono::milliseconds waitFor, const std::set<Token>& tokens) const
{
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    std::vector<IMqttClientImpl*> clients{};
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (m_principalClient) {
            clients.push_back(m_principalClient.get());
        }
        for (const auto& c : m_additionalClients) {
            clients.push_back(c.get());
        }
    }
    return waitForCompletionInternal(clients, waitFor, tokens);
}

/**
 * @brief Look up the result of a command.
 * @param token The token of the command; its client id selects the wrapped client to ask.
 * @return The operation result reported by the wrapped client whose id matches the token,
 *         or an empty optional when no wrapped client has that id.
 */
boost::optional<bool> MqttClient::commandResult(const Token& token) const
{
    OSDEV_COMPONENTS_LOCKGUARD(m_interfaceMutex);
    std::vector<IMqttClientImpl*> wrappers{};
    {
        // Snapshot the wrapped clients while holding the internal mutex.
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (m_principalClient) {
            wrappers.push_back(m_principalClient.get());
        }
        for (const auto& additional : m_additionalClients) {
            wrappers.push_back(additional.get());
        }
    }

    boost::optional<bool> outcome{};
    for (auto* wrapper : wrappers) {
        if (token.clientId() == wrapper->clientId()) {
            // The first wrapper with a matching id decides the result.
            outcome = wrapper->operationResult(token.token());
            break;
        }
    }
    return outcome;
}

std::string MqttClient::endpoint() const
{
    auto ep = UriParser::parse(m_endpoint);
    if (ep.find("user") != ep.end()) {
        ep["user"].clear();
    }
    if (ep.find("password") != ep.end()) {
        ep["password"].clear();
    }
    return UriParser::toString(ep);
}

/**
 * @brief Handle a connection status change reported by a wrapped client.
 *
 * The connection states of all wrapped clients are combined into a new overall state. When the server
 * state goes from ConnectionFailure to Good, the wrapped clients are asked to resubscribe. The state
 * change itself is emitted from an event that is pushed on the event queue.
 *
 * @param id The id of the wrapped client that reported the change (currently unused).
 * @param cs The new connection status of that client (currently unused).
 */
void MqttClient::connectionStatusChanged(const std::string& id, ConnectionStatus cs)
{
    // Both parameters are only referenced by the disabled log statement below.
    (void)id;
    (void)cs;
    // DebugLogToFIle ("MqttClient", "%1 - connection status of wrapped client %2 changed to %3", m_clientId, id, cs);
    IMqttClientImpl* principalClient{ nullptr };
    std::vector<IMqttClientImpl*> clients{};
    std::vector<ConnectionStatus> connectionStates{};
    {
        // Snapshot the wrapped clients and their connection states while holding the internal mutex.
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_principalClient) {
            return;
        }
        // Always true at this point because of the early return above.
        if (m_principalClient) {
            principalClient = m_principalClient.get();
            clients.push_back(principalClient);
            connectionStates.push_back(m_principalClient->connectionStatus());
        }
        for (const auto& c : m_additionalClients) {
            clients.push_back(c.get());
            connectionStates.push_back(c->connectionStatus());
        }
    }
    auto newState = determineState(connectionStates);
    // NOTE(review): m_serverState.state() is read here without holding m_internalMutex - confirm this is safe.
    bool resubscribe = (StateEnum::ConnectionFailure == m_serverState.state() && StateEnum::Good == newState);
    if (resubscribe) {
        {
            // Drop all active tokens before resubscribing.
            OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
            m_activeTokens.clear();
        }
        for (auto* cl : clients) {
            try {
                cl->resubscribe();
            }
            catch (const std::exception& e) {
                // The failure is swallowed; the log statement that would report it is disabled.
                // ErrorLogToFIle ("MqttClient", "%1 - resubscribe on wrapped client %2 in context of connection status change in wrapped client %3 failed : %4", m_clientId, cl->clientId(), id, e.what());
            }
        }
        // Wake up waiters on the active token set, which was cleared above.
        m_activeTokensCV.notify_all();
    }

    // The server state change and a possible resubscription are done in the context of the MqttClient worker thread
    // The wrapper is free to pick up new work such as the acknowledgment of the just recreated subscriptions.
    this->pushEvent([this, resubscribe, clients, principalClient, newState]() {
        if (resubscribe) {
            // Just wait for the subscription commands to complete. We do not use waitForCompletionInternal because that call will always timeout when there are active tokens.
            // Active tokens are removed typically by work done on the worker thread. The wait action is also performed on the worker thread.
            auto waitFor = std::chrono::milliseconds(1000);
            if (!waitForCompletionInternalClients(clients, waitFor, std::set<Token>{})) {
                // Timed out: check whether any wrapped client still has pending subscriptions. Nothing is done
                // with the outcome because the log statement is disabled.
                if (std::accumulate(clients.begin(), clients.end(), false, [](bool hasPending, IMqttClientImpl* client) { return hasPending || client->hasPendingSubscriptions(); })) {
                    // WarnLogToFIle ("MqttClient", "%1 - subscriptions are not recovered within timeout.", m_clientId)
                }
            }
            if (principalClient) {
                try {
                    principalClient->publishPending();
                }
                catch (const std::exception& e) {
                    // The failure is swallowed; the log statement that would report it is disabled.
                    // ErrorLogToFIle ("MqttClient", "%1 - publishPending on wrapped client %2 failed : %3", m_clientId, principalClient->clientId(), e.what());
                }
            }
        }
        // Emitted last, after a possible resubscription has been handled.
        m_serverState.emitStateChanged(newState);
    });
}

void MqttClient::deliveryComplete(const std::string& _clientId, std::int32_t token)
{
    if (!m_deliveryCompleteCallback) {
        return;
    }

    Token t(_clientId, token);
    {
        OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
        if (!m_activeTokens.insert(t).second) {
            // This should not happen. This means that some callback on the wrapper never came.
            // ErrorLogToFIle ("MqttClient", "%1 -deliveryComplete, token %1 is already active", m_clientId, t);
        }
    }
    this->pushEvent([this, t]() {
        OSDEV_COMPONENTS_SCOPEGUARD(m_activeTokens, [this, &t]() {
            {
                OSDEV_COMPONENTS_LOCKGUARD(m_internalMutex);
                m_activeTokens.erase(t);
            }
            m_activeTokensCV.notify_all();
        });
        m_deliveryCompleteCallback(t);
    });
}

/**
 * @brief Wait for commands to complete, on the wrapped clients and in the active token set.
 *
 * First the wrapped clients are waited for. That step reduces waitFor by the time it used, and only the
 * remainder is used to wait until none of the requested tokens is active anymore.
 *
 * @param clients The wrapped clients to wait for.
 * @param waitFor The maximum time to wait.
 * @param tokens The tokens to wait for; an empty set means all active tokens.
 * @return false when the wrapped clients did not complete in time or the active tokens did not clear
 *         within the remaining time, true otherwise.
 */
bool MqttClient::waitForCompletionInternal(const std::vector<IMqttClientImpl*>& clients, std::chrono::milliseconds waitFor, const std::set<Token>& tokens) const
{
    const bool wrappersDone = waitForCompletionInternalClients(clients, waitFor, tokens);
    if (!wrappersDone) {
        return false;
    }

    // True when none of the requested tokens is active anymore.
    const auto noneActive = [this, &tokens]() -> bool {
        if (tokens.empty()) { // wait for all operations to end
            return m_activeTokens.empty();
        }
        if (1 == tokens.size()) {
            return m_activeTokens.end() == m_activeTokens.find(*tokens.cbegin());
        }
        std::vector<Token> stillActive{};
        std::set_intersection(m_activeTokens.begin(), m_activeTokens.end(), tokens.begin(), tokens.end(), std::back_inserter(stillActive));
        return stillActive.empty();
    };

    std::unique_lock<std::mutex> guard(m_internalMutex);
    return m_activeTokensCV.wait_for(guard, waitFor, noneActive);
}

/**
 * @brief Wait for commands to complete on the wrapped clients.
 *
 * Each wrapped client is asked to wait for its own tokens (selected by client id). When no tokens are
 * given, every client is asked to wait with an empty token set. The value returned by each client is
 * subtracted from waitFor.
 *
 * @param clients The wrapped clients to wait for.
 * @param waitFor In/out: the time budget. On return it holds the remaining time, clamped to zero when
 *                nothing is left.
 * @param tokens The tokens to wait for; an empty set means every client is waited for.
 * @return true when time is left after all waits, false otherwise.
 */
bool MqttClient::waitForCompletionInternalClients(const std::vector<IMqttClientImpl*>& clients, std::chrono::milliseconds& waitFor, const std::set<Token>& tokens) const
{
    const bool waitForAll = tokens.empty();
    for (auto* wrapper : clients) {
        // Collect the token numbers that belong to this wrapper.
        std::set<std::int32_t> ownTokens{};
        for (const auto& candidate : tokens) {
            if (wrapper->clientId() == candidate.clientId()) {
                ownTokens.insert(candidate.token());
            }
        }
        if (waitForAll || !ownTokens.empty()) {
            waitFor -= wrapper->waitForCompletion(waitFor, ownTokens);
        }
    }

    const std::chrono::milliseconds nothingLeft(0);
    if (waitFor <= nothingLeft) {
        waitFor = nothingLeft;
        return false;
    }
    return true;
}

/**
 * @brief Combine the connection states of the wrapped clients into one overall state.
 *
 * - ConnectionFailure when at least one client is in ReconnectInProgress.
 * - Otherwise Unknown when at least one client is Disconnected, DisconnectInProgress or ConnectInProgress.
 * - Otherwise Good when every client is Connected.
 * - Unknown in any other case.
 *
 * @param connectionStates The connection states of the wrapped clients.
 * @return The combined state.
 */
StateEnum MqttClient::determineState(const std::vector<ConnectionStatus>& connectionStates)
{
    std::size_t notConnected = 0;
    std::size_t connected = 0;
    std::size_t reconnecting = 0;
    for (const auto status : connectionStates) {
        if (ConnectionStatus::ReconnectInProgress == status) {
            ++reconnecting;
        }
        else if (ConnectionStatus::Connected == status) {
            ++connected;
        }
        else if (ConnectionStatus::Disconnected == status
                 || ConnectionStatus::DisconnectInProgress == status // we don't want resubscribes to trigger when a wrapper is in this state.
                 || ConnectionStatus::ConnectInProgress == status) { // the wrapper is not connected yet.
            ++notConnected;
        }
    }

    if (reconnecting > 0) {
        return StateEnum::ConnectionFailure;
    }
    if (notConnected > 0) {
        return StateEnum::Unknown;
    }
    if (connected == connectionStates.size()) {
        return StateEnum::Good;
    }
    return StateEnum::Unknown;
}

void MqttClient::pushEvent(std::function<void()> ev)
{
    m_eventQueue.push(ev);
}

void MqttClient::eventHandler()
{
    // InfoLogToFIle ("MqttClient", "%1 - starting event handler", m_clientId);
    for (;;) {
        std::vector<std::function<void()>> events;
        if (!m_eventQueue.pop(events))
        {
            break;
        }
        for (const auto& ev : events)
        {
            ev();
        }
    }
    // InfoLogToFIle ("MqttClient", "%1 - leaving event handler", m_clientId);
}