/******************************************************************************
*
*	CAEN SpA - Software Division
*	Via Vetraia, 11 - 55049 - Viareggio ITALY
*	+39 0594 388 398 - www.caen.it
*
*******************************************************************************
*
*	Copyright (C) 2019-2022 CAEN SpA
*
*	This file is part of the CAEN Utility.
*
*	The CAEN Utility is free software; you can redistribute it and/or
*	modify it under the terms of the GNU Lesser General Public
*	License as published by the Free Software Foundation; either
*	version 3 of the License, or (at your option) any later version.
*
*	The CAEN Utility is distributed in the hope that it will be useful,
*	but WITHOUT ANY WARRANTY; without even the implied warranty of
*	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
*	Lesser General Public License for more details.
*
*	You should have received a copy of the GNU Lesser General Public
*	License along with the CAEN Utility; if not, see
*	https://www.gnu.org/licenses/.
*
*	SPDX-License-Identifier: LGPL-3.0-or-later
*
***************************************************************************//*!
*
*	\file		CAENSocket.c
*	\brief		TCP/IP functions.
*	\author
*
******************************************************************************/

#ifdef _WIN32
#include <WS2tcpip.h>
#endif

#include <CAENSocket.h>

#include <limits.h>

#ifndef _WIN32
#include <netdb.h>
#include <arpa/inet.h> // inet_ntop
#include <sys/socket.h> // setsockopt, socklen_t, AF_INET
#include <unistd.h> // close
#endif

#include <CAENLogger.h>
#include <CAENMultiplatform.h>
#include <CAENThread.h>

// File-level logger used by every logMsg() call in this file; the arguments are the log file
// name and this source file's name. NOTE(review): exact semantics presumably defined by
// INIT_C_LOGGER in CAENLogger.h — confirm there.
INIT_C_LOGGER("CAENSocket.log", "CAENSocket.c");

#ifdef _WIN32
/*
 * Log a Win32/Winsock error code as a human-readable message.
 *
 * \param function	name of the failed API, used as prefix in the log line.
 * \param hResult	the error code to describe (e.g. the value returned by WSAStartup()
 *					or by WSAGetLastError()).
 */
static void _printWinError(const char *function, DWORD hResult) {
	// see https://stackoverflow.com/a/455533/3287591
	LPSTR errorText = NULL;
	const DWORD r = FormatMessageA(
		FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_IGNORE_INSERTS,
		NULL,
		hResult,
		MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
		(LPSTR)&errorText,
		0,
		NULL
	);
	if (r == 0 || errorText == NULL) {
		// FormatMessage failed: fall back to the numeric code.
		logMsg(c_logger_Severity_ERROR, "%s() failed: %"PRIu32".", function, (uint32_t)hResult);
		return;
	}
	// System messages usually end with ".\r\n": strip trailing line breaks, blanks and the
	// final period so the log line stays on one line and does not end with "..".
	// r is the number of characters stored, excluding the terminating null.
	for (DWORD len = r; len > 0; --len) {
		const char last = errorText[len - 1];
		if (last != '\r' && last != '\n' && last != ' ' && last != '.')
			break;
		errorText[len - 1] = '\0';
	}
	logMsg(c_logger_Severity_ERROR, "%s() failed: %s.", function, errorText);
	// Buffer was allocated by FormatMessage (FORMAT_MESSAGE_ALLOCATE_BUFFER).
	LocalFree(errorText);
}
#endif

/*
 * Initialize the platform socket library.
 * On Windows this calls WSAStartup(); on other platforms there is nothing to do.
 *
 * \return c_Socket_ErrorCode_Success, or c_Socket_ErrorCode_Init if WSAStartup() fails.
 */
int32_t c_socket_init(void) {
#ifdef _WIN32
	WSADATA wsaData;
	// WSAStartup() reports its error code directly as return value (0 on success).
	const int startupError = WSAStartup(WINSOCK_VERSION, &wsaData);
	if (startupError != 0) {
		_printWinError("WSAStartup", startupError);
		return c_Socket_ErrorCode_Init;
	}
#endif
	return c_Socket_ErrorCode_Success;
}

/*
 * Release the platform socket library (counterpart of c_socket_init()).
 * On Windows this calls WSACleanup(); on other platforms there is nothing to do.
 *
 * \return c_Socket_ErrorCode_Success, or c_Socket_ErrorCode_Init if WSACleanup() fails.
 */
int32_t c_socket_cleanup(void) {
	int32_t ret = c_Socket_ErrorCode_Success;
#ifdef _WIN32
	// Unlike WSAStartup(), WSACleanup() does not return the error code: it returns 0 on
	// success or SOCKET_ERROR, and the actual code (WSANOTINITIALISED, WSAENETDOWN,
	// WSAEINPROGRESS, ...) must be fetched with WSAGetLastError().
	if (WSACleanup() != 0) {
		_printWinError("WSACleanup", (DWORD)WSAGetLastError());
		ret = c_Socket_ErrorCode_Init;
	}
#endif
	return ret;
}

/*
 * Allocate a new socket object with an invalid OS socket and an initialized mutex.
 * The socket library is initialized (c_socket_init()) once per object; c_socket_delete()
 * performs the matching c_socket_cleanup().
 *
 * \return the new object, or NULL if allocation, socket-library init or mutex init failed.
 */
c_use_decl_annotations c_socket_t* c_socket_new(void) {
	c_socket_t *sckt = c_malloc(sizeof(*sckt));
	if (sckt == NULL)
		return NULL;
	sckt->socket = c_socket_invalid;
	// WSAStartup on Windows. Only a successful init may later be paired with the
	// c_socket_cleanup() done by c_socket_delete(): an unmatched WSACleanup would
	// release a Winsock reference owned by someone else.
	if (c_socket_init() != c_Socket_ErrorCode_Success) {
		logMsg(c_logger_Severity_ERROR, "%s(): Error in c_socket_init().", __func__);
		c_free(sckt);
		return NULL;
	}
	if (c_mutex_init(&sckt->mutex) != CAENThread_RetCode_Success) {
		logMsg(c_logger_Severity_ERROR, "%s(): Error in c_mutex_init().", __func__);
		c_socket_cleanup();
		c_free(sckt);
		return NULL;
	}
	return sckt;
}

/*
 * Destroy a socket object created by c_socket_new(): close its OS socket (if still open),
 * undo the socket-library initialization performed at creation, destroy the mutex and
 * free the memory. A NULL argument is ignored.
 */
void c_socket_delete(c_socket_t *sckt) {
	if (c_unlikely(sckt == NULL)) {
		return;
	}
	c_socket_reset(sckt);			// closes sckt->socket unless already invalid
	c_socket_cleanup();				// WSACleanup must match each call to WSAStartup
	c_mutex_destroy(&sckt->mutex);
	c_free(sckt);
}

/*
 * Close the OS socket held by sckt (if any) and mark it invalid.
 * Errors from close/closesocket are logged, not returned.
 */
void c_socket_reset(c_socket_t *sckt) {
	// Nothing to do for a NULL object or an already-closed socket.
	if (c_unlikely(sckt == NULL) || sckt->socket == c_socket_invalid)
		return;

	// On Windows socket() returns a handle to a kernel object, so it must be closed with closesocket().
#ifdef _WIN32
	const int closeResult = closesocket(sckt->socket);
#else
	const int closeResult = close(sckt->socket);
#endif

	if (closeResult == c_socket_error)
		logMsg(c_logger_Severity_ERROR, "%s(): close/closesocket failed. Error %d.", __func__, c_socket_errno);

	// Invalidate even if close failed, so the guard above prevents a second close.
	sckt->socket = c_socket_invalid;
}

/*
 * Receive exactly totSize bytes from sckt into buffer, looping over partial recv() results.
 *
 * \param sckt		socket to read from (must not be NULL).
 * \param buffer	destination of at least totSize bytes (may be NULL only if totSize is 0).
 * \param totSize	number of bytes to read; must not exceed c_socket_max_ssize.
 * \return the number of bytes read (fewer than totSize only if the peer closed the connection),
 *         c_socket_error if recv() failed, -1 on invalid arguments, c_Socket_ErrorCode_Send
 *         if totSize is too large.
 */
c_ssize_t c_recv(const c_socket_t *sckt, void *buffer, size_t totSize) {
	if (c_unlikely(sckt == NULL)) {
		logMsg(c_logger_Severity_ERROR, "%s(): Invalid c_socket_t.", __func__);
		return -1;
	}
	if (c_unlikely(buffer == NULL && totSize != 0)) {
		logMsg(c_logger_Severity_ERROR, "%s(): Invalid buffer.", __func__);
		return -1;
	}

	/*
	 * Windows:
	 * Size parameter type of recv is int: the cast from size_t to int may introduce sneaky bugs
	 * if totSize > INT_MAX.
	 *
	 * Linux:
	 * Size parameter type of recv is size_t, but the maximum accepted value is SSIZE_MAX: the
	 * cast from size_t to ssize_t may introduce sneaky bugs if totSize > SSIZE_MAX.
	 */
	if (c_unlikely(totSize > c_socket_max_ssize)) {
		// Cast: c_socket_max_ssize may not be a size_t, while %zu requires one.
		logMsg(c_logger_Severity_WARNING, "%s(): Packet too large (size '%zu' > max '%zu' bytes).", __func__, totSize, (size_t)c_socket_max_ssize);
		// NOTE(review): a receive function returning the *Send* error code; kept for compatibility.
		return c_Socket_ErrorCode_Send;
	}

	const c_ssize_t totalToRead = (c_ssize_t)totSize; // totSize now can be casted safely
	c_ssize_t stillToRead = totalToRead;

	char *buffer_temp = (char*)buffer;

	// Plain while (not do/while): a zero-sized request must not call recv(), whose 0 return
	// value would otherwise be misreported as a disconnection.
	while (stillToRead > 0) {
		const c_ssize_t nbytes = recv(sckt->socket, buffer_temp, stillToRead, 0);
		if (nbytes == c_socket_error) {
			logMsg(c_logger_Severity_ERROR, "%s(): recv() failed. Error %d.", __func__, c_socket_errno);
			return nbytes;
		}
		if (nbytes == 0) {
			// recv returning 0 means a graceful disconnection on both POSIX and Windows.
			// cast to intmax_t required because c_ssize_t specifier is system dependent.
			logMsg(c_logger_Severity_INFO, "%s(): client disconnected (requested=%"PRIdMAX", remaining=%"PRIdMAX").", __func__, (intmax_t)totalToRead, (intmax_t)stillToRead);
			return totalToRead - stillToRead; // Anyway, return the actual number of bytes read
		}
		// recv() never returns more than requested, so stillToRead cannot become negative.
		stillToRead -= nbytes;
		buffer_temp += nbytes;
	}

	return totalToRead;
}

/*
 * Receive totSize bytes with c_recv(), then unlock sckt->mutex regardless of the outcome.
 * NOTE(review): presumably pairs with c_send_lock(), which leaves the mutex locked.
 *
 * \return the value returned by c_recv(), or -1 if sckt is NULL (mutex untouched).
 */
c_ssize_t c_recv_unlock(c_socket_t *sckt, void *buffer, size_t totSize) {
	if (c_unlikely(sckt == NULL)) {
		logMsg(c_logger_Severity_ERROR, "%s(): Invalid c_socket_t.", __func__);
		return -1;
	}

	const c_ssize_t received = c_recv(sckt, buffer, totSize);

	// An unlock failure is only logged: the receive result is still what the caller gets.
	if (c_mutex_unlock(&sckt->mutex) != CAENThread_RetCode_Success)
		logMsg(c_logger_Severity_ERROR, "%s(): Error in c_mutex_unlock().", __func__);

	return received;
}

/*
 * Send exactly totSize bytes from buffer on sckt, looping over partial send() results.
 *
 * \param sckt		socket to write to (must not be NULL).
 * \param buffer	data of at least totSize bytes (may be NULL only if totSize is 0).
 * \param totSize	number of bytes to send; must not exceed c_socket_max_ssize.
 * \return the number of bytes sent, c_socket_error if send() failed, -1 on invalid
 *         arguments, c_Socket_ErrorCode_Send if totSize is too large.
 */
c_ssize_t c_send(const c_socket_t *sckt, const void *buffer, size_t totSize) {
	if (c_unlikely(sckt == NULL)) {
		logMsg(c_logger_Severity_ERROR, "%s(): Invalid c_socket_t.", __func__);
		return -1;
	}
	if (c_unlikely(buffer == NULL && totSize != 0)) {
		logMsg(c_logger_Severity_ERROR, "%s(): Invalid buffer.", __func__);
		return -1;
	}

	/*
	 * Windows:
	 * Size parameter type of send is int: the cast from size_t to int may introduce sneaky bugs
	 * if totSize > INT_MAX.
	 * 
	 * Linux:
	 * Size parameter type of send is size_t, but the maximum accepted value is SSIZE_MAX: the
	 * cast from size_t to ssize_t may introduce sneaky bugs if totSize > SSIZE_MAX.
	 */
	if (c_unlikely(totSize > c_socket_max_ssize)) {
		// Cast: c_socket_max_ssize may not be a size_t, while %zu requires one.
		logMsg(c_logger_Severity_WARNING, "%s(): Packet too large (size '%zu' > max '%zu' bytes).", __func__, totSize, (size_t)c_socket_max_ssize);
		return c_Socket_ErrorCode_Send;
	}

	const c_ssize_t totalToSend = (c_ssize_t)totSize; // totSize now can be casted safely
	c_ssize_t stillToSend = totalToSend;

	const char* buffer_temp = (const char*)buffer;

	// Plain while (not do/while): nothing to send means no send() call at all.
	while (stillToSend > 0) {
		const c_ssize_t nbytes = send(sckt->socket, buffer_temp, stillToSend, 0);
		if (nbytes == c_socket_error) {
			logMsg(c_logger_Severity_ERROR, "%s(): send() failed. Error %d.", __func__, c_socket_errno);
			return nbytes;
		}
		// send() never reports more than requested, so stillToSend cannot become negative.
		stillToSend -= nbytes;
		buffer_temp += nbytes;
	}

	return totalToSend;
}

/*
 * Lock sckt->mutex, then send totSize bytes with c_send(). The mutex is NOT released here.
 * NOTE(review): presumably released later by c_recv_unlock(); if c_send() fails the mutex
 * stays locked, so the caller must still unlock it — confirm against callers.
 *
 * \return the value returned by c_send(), -1 if sckt is NULL, c_socket_error if locking fails.
 */
c_ssize_t c_send_lock(c_socket_t *sckt, const void *buffer, size_t totSize) {
	if (c_unlikely(sckt == NULL)) {
		logMsg(c_logger_Severity_ERROR, "%s(): Invalid c_socket_t.", __func__);
		return -1;
	}

	if (c_mutex_lock(&sckt->mutex) == CAENThread_RetCode_Success)
		return c_send(sckt, buffer, totSize);

	logMsg(c_logger_Severity_ERROR, "%s(): Error in c_mutex_lock().", __func__);
	return c_socket_error;
}

/*
 * Create a socket object holding a new IPv4 TCP socket.
 *
 * \return NULL if the object could not be created; otherwise the object, whose ->socket
 *         is c_socket_invalid if socket() failed (callers must check it).
 */
c_use_decl_annotations c_socket_t* c_tcp_socket(void) {
	c_socket_t *const tcp = c_socket_new();
	if (tcp != NULL)
		tcp->socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
	else
		logMsg(c_logger_Severity_ERROR, "%s(): Error in c_socket_new().", __func__);
	return tcp;
}

/*
 * bind() wrapper for a c_socket_t.
 * \return the value of bind(), or c_socket_error if sckt is NULL.
 */
int c_bind(const c_socket_t *sckt, const struct sockaddr *addr, c_socklen_t addrlen) {
	if (c_unlikely(sckt == NULL)) {
		logMsg(c_logger_Severity_ERROR, "%s(): Invalid c_socket_t.", __func__);
		return c_socket_error;
	}
	return bind(sckt->socket, addr, addrlen);
}

int c_listen(const c_socket_t *sckt, int backlog) {
	return listen(sckt->socket, backlog);
}

/*
 * Accept a connection on the listening socket sckt.
 *
 * \return NULL if sckt is NULL or the client object could not be created; otherwise a new
 *         object whose ->socket is c_socket_invalid if accept() failed (callers must check
 *         it and delete the object).
 */
c_use_decl_annotations c_socket_t* c_accept(const c_socket_t *sckt, struct sockaddr *addr, c_socklen_t*addrlen) {
	if (c_unlikely(sckt == NULL)) {
		logMsg(c_logger_Severity_ERROR, "%s(): Invalid c_socket_t.", __func__);
		return NULL;
	}
	c_socket_t *client = c_socket_new();
	if (client == NULL) {
		logMsg(c_logger_Severity_ERROR, "%s(): Error in c_socket_new().", __func__);
		return client;
	}
	// accept() is the last call here, so callers can still read its error code.
	client->socket = accept(sckt->socket, addr, addrlen);
	return client;
}

/*
 * connect() wrapper for a c_socket_t.
 * \return the value of connect(), or c_socket_error if sckt is NULL.
 */
int c_connect(const c_socket_t *sckt, const struct sockaddr *addr, c_socklen_t addrlen) {
	if (c_unlikely(sckt == NULL)) {
		logMsg(c_logger_Severity_ERROR, "%s(): Invalid c_socket_t.", __func__);
		return c_socket_error;
	}
	return connect(sckt->socket, addr, addrlen);
}


/*
 * Create an IPv4 TCP server socket bound to inaddr:*_port and put it in listening state.
 *
 * \param server	[out] on success receives the listening socket; set to NULL on failure.
 * \param inaddr	IPv4 address to bind to, in host byte order (converted with c_hton32).
 * \param _port		[in,out] port to bind to, in host byte order. If 0 the OS picks a port and
 *					the chosen one is written back (this requires getsockname() to succeed).
 * \return c_Socket_ErrorCode_Success, or c_Socket_ErrorCode_Argument/Init/Bind/Listen/GenericError.
 */
int32_t c_socket_server_init(c_socket_t **server, uint32_t inaddr, uint16_t *_port) {
	int32_t ret = c_Socket_ErrorCode_Success;
	int iret;
	const int on = 1;
	struct sockaddr_in addr_bind = { 0 };
	struct sockaddr_in addr_server = { 0 };
	c_socklen_t addr_server_len = sizeof(addr_server);

	// Validate both out-parameters before dereferencing them.
	if (server == NULL || _port == NULL)
		return c_Socket_ErrorCode_Argument;

	const uint16_t port = *_port;

	c_socket_t *server_local = c_tcp_socket();
	if (server_local == NULL || server_local->socket == c_socket_invalid) {
		logMsg(c_logger_Severity_ERROR, "%s(): socket() failed. Error %d.", __func__, c_socket_errno);
		ret = c_Socket_ErrorCode_Init;
		goto QuitFunction;
	}
	logMsg(c_logger_Severity_DEBUG, "Socket created.");

	iret = setsockopt(server_local->socket, SOL_SOCKET, SO_REUSEADDR, (const char*)&on, sizeof(on));
	if (iret == c_socket_error) {
		logMsg(c_logger_Severity_ERROR, "%s(): setsockopt() failed. Error %d.", __func__, c_socket_errno);
		ret = c_Socket_ErrorCode_Init;
		goto QuitFunction;
	}

	addr_bind.sin_family = AF_INET;
	addr_bind.sin_addr.s_addr = c_hton32(inaddr);
	addr_bind.sin_port = c_hton16(port);
	iret = c_bind(server_local, (struct sockaddr*)&addr_bind, sizeof(addr_bind));
	if (iret == c_socket_error) {
		logMsg(c_logger_Severity_ERROR, "%s(): bind() failed. Error %d.", __func__, c_socket_errno);
		ret = c_Socket_ErrorCode_Bind;
		goto QuitFunction;
	}
	logMsg(c_logger_Severity_DEBUG, "Socket bind success.");

	iret = c_listen(server_local, SOMAXCONN);
	if (iret == c_socket_error) {
		logMsg(c_logger_Severity_ERROR, "%s(): listen() failed. Error %d.", __func__, c_socket_errno);
		ret = c_Socket_ErrorCode_Listen;
		goto QuitFunction;
	}

	// Query the address actually bound: mandatory when the port was chosen by the OS.
	iret = getsockname(server_local->socket, (struct sockaddr*)&addr_server, &addr_server_len);
	if (iret == c_socket_error) {
		if (port == 0) {
			logMsg(c_logger_Severity_ERROR, "getsockname() failed with error %d. This call is mandatory because port is automatic.", c_socket_errno);
			ret = c_Socket_ErrorCode_GenericError;
			goto QuitFunction;
		}
		else {
			logMsg(c_logger_Severity_WARNING, "Ready to receive connections. getsockname() failed. Error %d. Going on anyway.", c_socket_errno);
		}
	}
	else {
		char buf[INET_ADDRSTRLEN];
		if (inet_ntop(AF_INET, &addr_server.sin_addr, buf, sizeof(buf)) != NULL)
			logMsg(c_logger_Severity_INFO, "Ready to receive connections from '%s' on port '%"PRIu16"'.", buf, c_ntoh16(addr_server.sin_port));
		else
			logMsg(c_logger_Severity_INFO, "Ready to receive connections on port '%"PRIu16"'. inet_ntop() failed. Error %d. Going on anyway.", c_ntoh16(addr_server.sin_port), c_socket_errno);
		if (port == 0)
			*_port = c_ntoh16(addr_server.sin_port);
	}

QuitFunction:
	if (ret != c_Socket_ErrorCode_Success) {
		logMsg(c_logger_Severity_INFO, "%s(): failed. Error %"PRIi32, __func__, ret);
		// Closes the OS socket (if any) and releases the object; NULL-safe.
		c_socket_delete(server_local);
		*server = NULL;
	}
	else {
		*server = server_local;
	}
	return ret;
}

/*
 * Block until a client connects to the listening socket server.
 *
 * \param server	listening socket (e.g. from c_socket_server_init()).
 * \param client	[out] the connected client socket; NULL on failure.
 * \return c_Socket_ErrorCode_Success, or c_Socket_ErrorCode_GenericError on bad arguments
 *         or if accept() failed.
 */
int32_t c_socket_server_accept(const c_socket_t *server, c_socket_t **client) {
	if (server == NULL || client == NULL)
		return c_Socket_ErrorCode_GenericError;

	*client = NULL;
	logMsg(c_logger_Severity_INFO, "Waiting for connection...");

	struct sockaddr_in peer = { 0 };
	c_socklen_t peer_len = sizeof(peer);
	c_socket_t *const accepted = c_accept(server, (struct sockaddr*)&peer, &peer_len);

	const int accepted_ok = (accepted != NULL) && (accepted->socket != c_socket_invalid);
	if (!accepted_ok) {
		// Log before deleting, so the accept() error code is still the current one.
		logMsg(c_logger_Severity_ERROR, "%s(): accept() failed. Error %d.", __func__, c_socket_errno);
		if (accepted != NULL)
			c_socket_delete(accepted);
		return c_Socket_ErrorCode_GenericError;
	}

	*client = accepted;

	// Client connected: report its address when it can be formatted.
	char peer_str[INET_ADDRSTRLEN];
	if (inet_ntop(AF_INET, &peer.sin_addr, peer_str, sizeof(peer_str)) == NULL)
		logMsg(c_logger_Severity_WARNING, "Connected to a client. inet_ntop() failed. Error %d. Going on anyway.", c_socket_errno);
	else
		logMsg(c_logger_Severity_INFO, "Connected to '%s'.", peer_str);

	return c_Socket_ErrorCode_Success;
}

/*
 * Create a TCP socket and connect it to addr_server.
 *
 * \param client		[out] the connected socket on success, NULL on failure.
 * \param addr_server	address to connect to.
 * \param addrlen		size of *addr_server.
 * \return c_Socket_ErrorCode_Success, or c_Socket_ErrorCode_Argument/Init/Connect.
 */
int32_t c_socket_client_sockaddr_connect(c_socket_t **client, const struct sockaddr *addr_server, c_socklen_t addrlen) {
	if (client == NULL)
		return c_Socket_ErrorCode_Argument;

	int32_t ret = c_Socket_ErrorCode_Success;
	c_socket_t *conn = c_tcp_socket();

	if (conn == NULL || conn->socket == c_socket_invalid) {
		logMsg(c_logger_Severity_ERROR, "socket() failed. Error %d.", c_socket_errno);
		ret = c_Socket_ErrorCode_Init;
	}
	else {
		logMsg(c_logger_Severity_INFO, "Socket created.");

		int on = 1;
		if (setsockopt(conn->socket, SOL_SOCKET, SO_REUSEADDR, (char*)&on, sizeof(on)) == c_socket_error) {
			logMsg(c_logger_Severity_ERROR, "%s(): setsockopt() failed. Error %d.", __func__, c_socket_errno);
			ret = c_Socket_ErrorCode_Init;
		}
		else if (c_connect(conn, addr_server, addrlen) == c_socket_error) {
			logMsg(c_logger_Severity_ERROR, "%s(): connect() failed. Error %d.", __func__, c_socket_errno);
			ret = c_Socket_ErrorCode_Connect;
		}
		else {
			logMsg(c_logger_Severity_INFO, "Connected");
		}
	}

	// On any failure release the half-built socket (c_socket_delete is NULL-safe).
	if (ret != c_Socket_ErrorCode_Success) {
		logMsg(c_logger_Severity_INFO, "%s(): failed. Error %"PRIi32, __func__, ret);
		c_socket_delete(conn);
		conn = NULL;
	}
	*client = conn;
	return ret;
}

/*
 * Resolve hostname (IPv4 only) and connect a new TCP socket to it on port.
 *
 * \param client	[out] the connected socket on success, NULL on failure.
 * \param hostname	host name or dotted IPv4 address to resolve.
 * \param port		destination port, in host byte order.
 * \return c_Socket_ErrorCode_Success, c_Socket_ErrorCode_Argument on NULL arguments,
 *         c_Socket_ErrorCode_DNS if resolution fails, or the error of the socket-library
 *         init/cleanup or of c_socket_client_sockaddr_connect().
 */
int32_t c_socket_client_connect(c_socket_t **client, const char *hostname, uint16_t port) {
	struct sockaddr_in addr_server = { 0 };
	int32_t ret = c_Socket_ErrorCode_Success;

	if (client == NULL || hostname == NULL)
		return c_Socket_ErrorCode_Argument;
	*client = NULL;

	// Resolve hostname. On Windows gethostbyname() needs an initialized Winsock, and the
	// returned hostent is owned by the socket library: the error code and the address are
	// read BEFORE c_socket_cleanup(), which may release that data and overwrite the error.
	// NOTE(review): gethostbyname() is not reentrant on POSIX; getaddrinfo() would be.
	if ((ret = c_socket_init()) != c_Socket_ErrorCode_Success)
		return ret;
	const struct hostent *he = gethostbyname(hostname);
	if (he == NULL) {
		logMsg(c_logger_Severity_ERROR, "gethostbyname() failed. Error %d.", c_socket_h_errno);
		ret = c_Socket_ErrorCode_DNS;
	}
	else if (he->h_addrtype != AF_INET
		|| he->h_length != (int)sizeof(addr_server.sin_addr.s_addr)
		|| he->h_addr_list[0] == NULL) {
		// Guard the copy below: only a single IPv4 address fits in sin_addr.
		logMsg(c_logger_Severity_ERROR, "%s(): gethostbyname() returned no IPv4 address for '%s'.", __func__, hostname);
		ret = c_Socket_ErrorCode_DNS;
	}
	else {
		c_memcpy(&addr_server.sin_addr.s_addr, he->h_addr_list[0], sizeof(addr_server.sin_addr.s_addr));
	}
	const int32_t cleanup_ret = c_socket_cleanup();
	if (cleanup_ret != c_Socket_ErrorCode_Success)
		return cleanup_ret;
	if (ret != c_Socket_ErrorCode_Success)
		goto QuitFunction;

	addr_server.sin_family = AF_INET;
	addr_server.sin_port = c_hton16(port);

	ret = c_socket_client_sockaddr_connect(client, (struct sockaddr*)&addr_server, sizeof(addr_server));
	if (ret != c_Socket_ErrorCode_Success) {
		logMsg(c_logger_Severity_ERROR, "%s(): connect() to '%s:%"PRIu16"' failed. Error %"PRIi32".", __func__, hostname, port, ret);
		goto QuitFunction;
	}

	logMsg(c_logger_Severity_INFO, "Connected to '%s:%"PRIu16"'.", hostname, port);

QuitFunction:
	if (ret != c_Socket_ErrorCode_Success)
		logMsg(c_logger_Severity_INFO, "%s(): failed. Error %"PRIi32".", __func__, ret);
	return ret;
}
