Reuse code for TCP and UDP connection

Put prefix length validation before memcmp() to avoid overflow
This commit is contained in:
wh201906 2023-12-26 16:30:36 +08:00
commit 39866f9ed2
No known key found for this signature in database
2 changed files with 58 additions and 308 deletions

View file

@ -102,14 +102,25 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
} }
str_lower(prefix); str_lower(prefix);
if (memcmp(prefix, "tcp:", 4) == 0) { bool isTCP = false;
free(prefix); bool isUDP = false;
bool isBluetooth = false;
if (strlen(pcPortName) <= 4) { bool isUnixSocket = false;
PrintAndLogEx(ERR, "error: tcp port name length too short"); if (strlen(prefix) > 4)
free(sp); {
return INVALID_SERIAL_PORT; isTCP = (memcmp(prefix, "tcp:", 4) == 0);
isUDP = (memcmp(prefix, "udp:", 4) == 0);
} }
if (strlen(prefix) > 3) {
isBluetooth = (memcmp(prefix, "bt:", 3) == 0);
}
if (strlen(prefix) > 7) {
isUnixSocket = (memcmp(prefix, "socket:", 7) == 0);
}
if (isTCP || isUDP) {
free(prefix);
struct addrinfo *addr = NULL, *rp; struct addrinfo *addr = NULL, *rp;
@ -167,7 +178,7 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
return INVALID_SERIAL_PORT; return INVALID_SERIAL_PORT;
} }
g_conn.send_via_ip = isIPv6 ? PM3_TCPv6 : PM3_TCPv4; g_conn.send_via_ip = isIPv6 ? (isTCP ? PM3_TCPv6 : PM3_UDPv6) : (isTCP ? PM3_TCPv4 : PM3_UDPv4);
portStr = (portStr == NULL) ? "18888" : portStr; portStr = (portStr == NULL) ? "18888" : portStr;
struct addrinfo info; struct addrinfo info;
@ -175,7 +186,7 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
memset(&info, 0, sizeof(info)); memset(&info, 0, sizeof(info));
info.ai_family = PF_UNSPEC; info.ai_family = PF_UNSPEC;
info.ai_socktype = SOCK_STREAM; info.ai_socktype = isTCP ? SOCK_STREAM : SOCK_DGRAM;
if ((strstr(addrStr, "localhost") != NULL) || if ((strstr(addrStr, "localhost") != NULL) ||
(strstr(addrStr, "127.0.0.1") != NULL) || (strstr(addrStr, "127.0.0.1") != NULL) ||
@ -185,7 +196,7 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
int s = getaddrinfo(addrStr, portStr, &info, &addr); int s = getaddrinfo(addrStr, portStr, &info, &addr);
if (s != 0) { if (s != 0) {
PrintAndLogEx(ERR, "error: getaddrinfo: %s", gai_strerror(s)); PrintAndLogEx(ERR, "error: getaddrinfo: %d: %s", s, gai_strerror(s));
freeaddrinfo(addr); freeaddrinfo(addr);
free(addrPortStr); free(addrPortStr);
free(sp); free(sp);
@ -225,148 +236,26 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
sp->fd = sfd; sp->fd = sfd;
if (isTCP) {
int one = 1; int one = 1;
int res = setsockopt(sp->fd, SOL_TCP, TCP_NODELAY, &one, sizeof(one)); int res = setsockopt(sp->fd, SOL_TCP, TCP_NODELAY, &one, sizeof(one));
if (res != 0) { if (res != 0) {
free(sp); free(sp);
return INVALID_SERIAL_PORT; return INVALID_SERIAL_PORT;
} }
} else if (isUDP) {
return sp;
}
if (memcmp(prefix, "udp:", 4) == 0) {
free(prefix);
if (strlen(pcPortName) <= 4) {
PrintAndLogEx(ERR, "error: udp port name length too short");
free(sp);
return INVALID_SERIAL_PORT;
}
struct addrinfo *addr = NULL, *rp;
char *addrPortStr = str_dup(pcPortName + 4);
if (addrPortStr == NULL) {
PrintAndLogEx(ERR, "error: string duplication");
free(sp);
return INVALID_SERIAL_PORT;
}
timeout.tv_usec = UART_NET_CLIENT_RX_TIMEOUT_MS * 1000;
// find the "bind" option
char *bindAddrPortStr = strstr(addrPortStr, ",bind=");
const char *bindAddrStr = NULL;
const char *bindPortStr = NULL;
bool isBindingIPv6 = false;
if (bindAddrPortStr != NULL) {
*bindAddrPortStr = '\0'; // as the end of target address (and port)
bindAddrPortStr += 6; // strlen(",bind=")
int result = uart_parse_address_port(bindAddrPortStr, &bindAddrStr, &bindPortStr, &isBindingIPv6);
if (result != PM3_SUCCESS) {
if (result == PM3_ESOFT) {
PrintAndLogEx(ERR, "error: wrong address: [] unmatched in bind option");
} else {
PrintAndLogEx(ERR, "error: failed to parse address and port in bind option");
}
free(addrPortStr);
free(sp);
return INVALID_SERIAL_PORT;
}
// for bind option, it's possible to only specify address or port
if (strlen(bindAddrStr) == 0)
bindAddrStr = NULL;
if (bindPortStr != NULL && strlen(bindPortStr) == 0)
bindPortStr = NULL;
}
const char *addrStr = NULL;
const char *portStr = NULL;
bool isIPv6 = false;
int result = uart_parse_address_port(addrPortStr, &addrStr, &portStr, &isIPv6);
if (result != PM3_SUCCESS) {
if (result == PM3_ESOFT) {
PrintAndLogEx(ERR, "error: wrong address: [] unmatched");
} else {
PrintAndLogEx(ERR, "error: failed to parse address and port");
}
free(addrPortStr);
free(sp);
return INVALID_SERIAL_PORT;
}
g_conn.send_via_ip = isIPv6 ? PM3_UDPv6 : PM3_UDPv4;
portStr = (portStr == NULL) ? "18888" : portStr;
struct addrinfo info;
memset(&info, 0, sizeof(info));
info.ai_family = PF_UNSPEC;
info.ai_socktype = SOCK_DGRAM;
if ((strstr(addrStr, "localhost") != NULL) ||
(strstr(addrStr, "127.0.0.1") != NULL) ||
(strstr(addrStr, "::1") != NULL)) {
g_conn.send_via_local_ip = true;
}
int s = getaddrinfo(addrStr, portStr, &info, &addr);
if (s != 0) {
PrintAndLogEx(ERR, "error: getaddrinfo: %s", gai_strerror(s));
freeaddrinfo(addr);
free(addrPortStr);
free(sp);
return INVALID_SERIAL_PORT;
}
int sfd;
for (rp = addr; rp != NULL; rp = rp->ai_next) {
sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (sfd == -1)
continue;
if (!uart_bind(&sfd, bindAddrStr, bindPortStr, isBindingIPv6)) {
PrintAndLogEx(ERR, "error: Could not bind. errno: %d", errno);
close(sfd);
freeaddrinfo(addr);
free(addrPortStr);
free(sp);
return INVALID_SERIAL_PORT;
}
if (connect(sfd, rp->ai_addr, rp->ai_addrlen) != -1)
break;
close(sfd);
}
freeaddrinfo(addr);
free(addrPortStr);
if (rp == NULL) { /* No address succeeded */
PrintAndLogEx(ERR, "error: Could not connect");
free(sp);
return INVALID_SERIAL_PORT;
}
sp->fd = sfd;
sp->udpBuffer = RingBuf_create(MAX(sizeof(PacketResponseNGRaw), sizeof(PacketResponseOLD)) * 30); sp->udpBuffer = RingBuf_create(MAX(sizeof(PacketResponseNGRaw), sizeof(PacketResponseOLD)) * 30);
}
return sp; return sp;
} }
if (isBluetooth) {
if (memcmp(prefix, "bt:", 3) == 0) {
free(prefix); free(prefix);
#ifdef HAVE_BLUEZ #ifdef HAVE_BLUEZ
if (strlen(pcPortName) != 20) { if (strlen(pcPortName) != 20) {
PrintAndLogEx(ERR, "Error: wrong Bluetooth MAC address length");
free(sp); free(sp);
return INVALID_SERIAL_PORT; return INVALID_SERIAL_PORT;
} }
@ -418,14 +307,9 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
// Is local socket buffer, not a TCP or any net connection! // Is local socket buffer, not a TCP or any net connection!
// so, you can't connect with address like: 127.0.0.1, or any IP // so, you can't connect with address like: 127.0.0.1, or any IP
// see http://man7.org/linux/man-pages/man7/unix.7.html // see http://man7.org/linux/man-pages/man7/unix.7.html
if (memcmp(prefix, "socket:", 7) == 0) { if (isUnixSocket) {
free(prefix); free(prefix);
if (strlen(pcPortName) <= 7) {
free(sp);
return INVALID_SERIAL_PORT;
}
// we must use max timeout! // we must use max timeout!
timeout.tv_usec = UART_NET_CLIENT_RX_TIMEOUT_MS * 1000; timeout.tv_usec = UART_NET_CLIENT_RX_TIMEOUT_MS * 1000;

View file

@ -103,15 +103,17 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
} }
str_lower(prefix); str_lower(prefix);
if (memcmp(prefix, "tcp:", 4) == 0) { bool isTCP = false;
free(prefix); bool isUDP = false;
if (strlen(prefix) > 4) {
if (strlen(pcPortName) <= 4) { isTCP = (memcmp(prefix, "tcp:", 4) == 0);
PrintAndLogEx(ERR, "error: tcp port name length too short"); isUDP = (memcmp(prefix, "udp:", 4) == 0);
free(sp);
return INVALID_SERIAL_PORT;
} }
if (isTCP || isUDP) {
free(prefix);
struct addrinfo *addr = NULL, *rp; struct addrinfo *addr = NULL, *rp;
char *addrPortStr = str_dup(pcPortName + 4); char *addrPortStr = str_dup(pcPortName + 4);
@ -168,7 +170,7 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
return INVALID_SERIAL_PORT; return INVALID_SERIAL_PORT;
} }
g_conn.send_via_ip = isIPv6 ? PM3_TCPv6 : PM3_TCPv4; g_conn.send_via_ip = isIPv6 ? (isTCP ? PM3_TCPv6 : PM3_UDPv6) : (isTCP ? PM3_TCPv4 : PM3_UDPv4);
portStr = (portStr == NULL) ? "18888" : portStr; portStr = (portStr == NULL) ? "18888" : portStr;
WSADATA wsaData; WSADATA wsaData;
@ -185,8 +187,8 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
memset(&info, 0, sizeof(info)); memset(&info, 0, sizeof(info));
info.ai_family = AF_UNSPEC; info.ai_family = AF_UNSPEC;
info.ai_socktype = SOCK_STREAM; info.ai_socktype = isTCP ? SOCK_STREAM : SOCK_DGRAM;
info.ai_protocol = IPPROTO_TCP; info.ai_protocol = isTCP ? IPPROTO_TCP : IPPROTO_UDP;
if ((strstr(addrStr, "localhost") != NULL) || if ((strstr(addrStr, "localhost") != NULL) ||
(strstr(addrStr, "127.0.0.1") != NULL) || (strstr(addrStr, "127.0.0.1") != NULL) ||
@ -241,6 +243,7 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
sp->hSocket = hSocket; sp->hSocket = hSocket;
if (isTCP) {
int one = 1; int one = 1;
int res = setsockopt(sp->hSocket, IPPROTO_TCP, TCP_NODELAY, (char *)&one, sizeof(one)); int res = setsockopt(sp->hSocket, IPPROTO_TCP, TCP_NODELAY, (char *)&one, sizeof(one));
if (res != 0) { if (res != 0) {
@ -249,147 +252,10 @@ serial_port uart_open(const char *pcPortName, uint32_t speed) {
free(sp); free(sp);
return INVALID_SERIAL_PORT; return INVALID_SERIAL_PORT;
} }
return sp; } else if (isUDP) {
}
if (memcmp(prefix, "udp:", 4) == 0) {
free(prefix);
if (strlen(pcPortName) <= 4) {
PrintAndLogEx(ERR, "error: udp port name length too short");
free(sp);
return INVALID_SERIAL_PORT;
}
struct addrinfo *addr = NULL, *rp;
char *addrPortStr = str_dup(pcPortName + 4);
if (addrPortStr == NULL) {
PrintAndLogEx(ERR, "error: string duplication");
free(sp);
return INVALID_SERIAL_PORT;
}
timeout.tv_usec = UART_NET_CLIENT_RX_TIMEOUT_MS * 1000;
// find the "bind" option
char *bindAddrPortStr = strstr(addrPortStr, ",bind=");
const char *bindAddrStr = NULL;
const char *bindPortStr = NULL;
bool isBindingIPv6 = false;
if (bindAddrPortStr != NULL) {
*bindAddrPortStr = '\0'; // as the end of target address (and port)
bindAddrPortStr += 6; // strlen(",bind=")
int result = uart_parse_address_port(bindAddrPortStr, &bindAddrStr, &bindPortStr, &isBindingIPv6);
if (result != PM3_SUCCESS) {
if (result == PM3_ESOFT) {
PrintAndLogEx(ERR, "error: wrong address: [] unmatched in bind option");
} else {
PrintAndLogEx(ERR, "error: failed to parse address and port in bind option");
}
free(addrPortStr);
free(sp);
return INVALID_SERIAL_PORT;
}
// for bind option, it's possible to only specify address or port
if (strlen(bindAddrStr) == 0)
bindAddrStr = NULL;
if (bindPortStr != NULL && strlen(bindPortStr) == 0)
bindPortStr = NULL;
}
const char *addrStr = NULL;
const char *portStr = NULL;
bool isIPv6 = false;
int result = uart_parse_address_port(addrPortStr, &addrStr, &portStr, &isIPv6);
if (result != PM3_SUCCESS) {
if (result == PM3_ESOFT) {
PrintAndLogEx(ERR, "error: wrong address: [] unmatched");
} else {
PrintAndLogEx(ERR, "error: failed to parse address and port");
}
free(addrPortStr);
free(sp);
return INVALID_SERIAL_PORT;
}
g_conn.send_via_ip = isIPv6 ? PM3_UDPv6 : PM3_UDPv4;
portStr = (portStr == NULL) ? "18888" : portStr;
WSADATA wsaData;
struct addrinfo info;
int iResult;
iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (iResult != 0) {
PrintAndLogEx(ERR, "error: WSAStartup failed with error: %d", iResult);
free(addrPortStr);
free(sp);
return INVALID_SERIAL_PORT;
}
memset(&info, 0, sizeof(info));
info.ai_family = AF_UNSPEC;
info.ai_socktype = SOCK_DGRAM;
info.ai_protocol = IPPROTO_UDP;
if ((strstr(addrStr, "localhost") != NULL) ||
(strstr(addrStr, "127.0.0.1") != NULL) ||
(strstr(addrStr, "::1") != NULL)) {
g_conn.send_via_local_ip = true;
}
int s = getaddrinfo(addrStr, portStr, &info, &addr);
if (s != 0) {
PrintAndLogEx(ERR, "error: getaddrinfo: %d: %s", s, gai_strerror(s));
freeaddrinfo(addr);
free(addrPortStr);
free(sp);
WSACleanup();
return INVALID_SERIAL_PORT;
}
SOCKET hSocket = INVALID_SOCKET;
for (rp = addr; rp != NULL; rp = rp->ai_next) {
hSocket = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (hSocket == INVALID_SOCKET)
continue;
if (!uart_bind(&hSocket, bindAddrStr, bindPortStr, isBindingIPv6)) {
PrintAndLogEx(ERR, "error: Could not bind. error: %u", WSAGetLastError());
closesocket(hSocket);
hSocket = INVALID_SOCKET;
freeaddrinfo(addr);
free(addrPortStr);
free(sp);
WSACleanup();
return INVALID_SERIAL_PORT;
}
if (connect(hSocket, rp->ai_addr, (int)rp->ai_addrlen) != INVALID_SOCKET)
break;
closesocket(hSocket);
hSocket = INVALID_SOCKET;
}
freeaddrinfo(addr);
free(addrPortStr);
if (rp == NULL) { /* No address succeeded */
PrintAndLogEx(ERR, "error: Could not connect");
WSACleanup();
free(sp);
return INVALID_SERIAL_PORT;
}
sp->hSocket = hSocket;
sp->udpBuffer = RingBuf_create(MAX(sizeof(PacketResponseNGRaw), sizeof(PacketResponseOLD)) * 30); sp->udpBuffer = RingBuf_create(MAX(sizeof(PacketResponseNGRaw), sizeof(PacketResponseOLD)) * 30);
}
return sp; return sp;
} }