net/nat: Optimize port selection

And fix possibly dead loop.

Signed-off-by: Zhe Weng <wengzhe@xiaomi.com>
This commit is contained in:
Zhe Weng 2024-04-11 18:28:00 +08:00 committed by Xiang Xiao
parent f3b34c84c2
commit 4eddf84a76

View File

@ -35,6 +35,22 @@
#ifdef CONFIG_NET_NAT #ifdef CONFIG_NET_NAT
/****************************************************************************
* Pre-processor Definitions
****************************************************************************/
#define NEXT_PORT(nport, hport) \
do \
{ \
++(hport); \
if ((hport) >= CONFIG_NET_DEFAULT_MAX_PORT || \
(hport) < CONFIG_NET_DEFAULT_MIN_PORT) \
{ \
(hport) = CONFIG_NET_DEFAULT_MIN_PORT; \
} \
(nport) = HTONS(hport); \
} while (0)
/**************************************************************************** /****************************************************************************
* Private Functions * Private Functions
****************************************************************************/ ****************************************************************************/
@ -50,7 +66,7 @@
* domain - The domain of the packet. * domain - The domain of the packet.
* protocol - The L4 protocol of the packet. * protocol - The L4 protocol of the packet.
* ip - The IP bind with the port (in network byte order). * ip - The IP bind with the port (in network byte order).
* portno - The local port (in network byte order), as reference. * local_port - The local port (in network byte order), as reference.
* *
* Returned Value: * Returned Value:
* port number on success; 0 on failure * port number on success; 0 on failure
@ -64,19 +80,19 @@
static uint16_t nat_port_select_without_stack( static uint16_t nat_port_select_without_stack(
uint8_t domain, uint8_t protocol, FAR const union ip_addr_u *ip, uint8_t domain, uint8_t protocol, FAR const union ip_addr_u *ip,
uint16_t portno) uint16_t local_port)
{ {
uint16_t portno = local_port;
uint16_t hport = NTOHS(portno); uint16_t hport = NTOHS(portno);
while (nat_port_inuse(domain, protocol, ip, portno)) while (nat_port_inuse(domain, protocol, ip, portno))
{ {
++hport; NEXT_PORT(portno, hport);
if (hport >= CONFIG_NET_DEFAULT_MAX_PORT || if (portno == local_port)
hport < CONFIG_NET_DEFAULT_MIN_PORT)
{ {
hport = CONFIG_NET_DEFAULT_MIN_PORT; /* We have looped back, failed. */
}
portno = HTONS(hport); return 0;
}
} }
return portno; return portno;
@ -292,14 +308,13 @@ uint16_t nat_port_select(FAR struct net_driver_s *dev,
while (icmp_findconn(dev, id) || while (icmp_findconn(dev, id) ||
nat_port_inuse(domain, IP_PROTO_ICMP, external_ip, id)) nat_port_inuse(domain, IP_PROTO_ICMP, external_ip, id))
{ {
++hid; NEXT_PORT(id, hid);
if (hid >= CONFIG_NET_DEFAULT_MAX_PORT || if (id == local_port)
hid < CONFIG_NET_DEFAULT_MIN_PORT)
{ {
hid = CONFIG_NET_DEFAULT_MIN_PORT; /* We have looped back, failed. */
}
id = HTONS(hid); return 0;
}
} }
return id; return id;
@ -319,14 +334,13 @@ uint16_t nat_port_select(FAR struct net_driver_s *dev,
while (icmpv6_active(id) || while (icmpv6_active(id) ||
nat_port_inuse(domain, IP_PROTO_ICMP6, external_ip, id)) nat_port_inuse(domain, IP_PROTO_ICMP6, external_ip, id))
{ {
++hid; NEXT_PORT(id, hid);
if (hid >= CONFIG_NET_DEFAULT_MAX_PORT || if (id == local_port)
hid < CONFIG_NET_DEFAULT_MIN_PORT)
{ {
hid = CONFIG_NET_DEFAULT_MIN_PORT; /* We have looped back, failed. */
}
id = HTONS(hid); return 0;
}
} }
return id; return id;