# C# 一种TCP连接程序白名单的方法,仅允许指定程序连接

一种 TCP 连接程序白名单的方法:服务端接受连接后,通过系统 TCP 连接表(iphlpapi.dll,仅适用于 Windows 且仅查询 IPv4)查出该连接所属的进程 ID,仅允许白名单内的程序连接,有效阻止恶意连接

可用于主动关闭指定TCP连接

C#
// Configure the whitelist: IDs of the processes that are allowed to connect.
int[] whiteProcessIds = Array.Empty<int>();

// Verify each connection right after it has been accepted.
var socket = server.Accept();
// Direct casts instead of `as`: an unexpected endpoint type then fails with a clear
// InvalidCastException rather than a NullReferenceException on `.Port`.
int port = ((IPEndPoint)socket.RemoteEndPoint).Port;
int listenPort = ((IPEndPoint)server.LocalEndPoint).Port;
if (!VerifyProcess(port, listenPort, whiteProcessIds))
{
    // Rejected: release the accepted socket's handle before bailing out,
    // otherwise every refused connection leaks a Socket object.
    socket.Close();
    return;
}


/// <summary>
/// Looks up the established TCP connections that run between <paramref name="port"/> and
/// <paramref name="listenPort"/>, forcibly closes every one whose owning process is not
/// whitelisted, and reports whether the connection may be kept.
/// </summary>
/// <param name="port">Remote port of the accepted socket.</param>
/// <param name="listenPort">Local port the server socket is bound to.</param>
/// <param name="whiteProcessIds">IDs of the processes that are allowed to own the connection.</param>
/// <returns>
/// <c>true</c> when no matching connection is owned by a non-whitelisted process;
/// <c>false</c> when at least one such connection was found (and a kill was attempted),
/// or when the TCP table could not be read.
/// </returns>
private bool VerifyProcess(int port,int listenPort,int[] whiteProcessIds)
{
    var table = TcpManager.GetAllTcpConnections();
    if (table == null)
    {
        // The TCP table could not be read, so the owner cannot be verified.
        // Fail closed instead of crashing inside the LINQ query below.
        return false;
    }

    // NOTE(review): a connection with no matching row at all (e.g. an IPv6 connection, since only
    // the IPv4 table is queried) yields an empty result and is therefore allowed - confirm intended.
    // NOTE(review): the row seen from the listening side is presumably owned by this server
    // process, so its own PID likely has to be part of the whitelist - confirm against the caller.
    var offenders = table
        .Where(row =>
        {
            int localPort = TcpManager.TranslatePort(row.dwLocalPort);
            int remotePort = TcpManager.TranslatePort(row.dwRemotePort);

            // Established connections only, matching the port pair in either direction.
            return row.dwState == TcpState.Established
                && ((localPort == listenPort && remotePort == port) || (remotePort == listenPort && localPort == port));
        })
        // Keep only the rows whose owner is NOT whitelisted.
        .Where(row => !whiteProcessIds.Contains(row.dwOwningPid))
        // Materialize once: the result is iterated for killing and then inspected again.
        .ToList();

    // Every remaining row is an unwanted connection: tear it down (best effort, result ignored).
    foreach (var row in offenders)
    {
        int localPort = TcpManager.TranslatePort(row.dwLocalPort);
        int remotePort = TcpManager.TranslatePort(row.dwRemotePort);
        TcpManager.Kill(new IPEndPoint(row.dwLocalAddr, localPort),new IPEndPoint(row.dwRemoteAddr, remotePort));
    }

    // Any offender means this connection has to be blocked.
    return offenders.Count == 0;
}
C#
/// <summary>
/// Thin wrapper around the iphlpapi.dll TCP table functions: enumerates the IPv4 TCP
/// connections together with their owning process IDs and can delete a single connection.
/// </summary>
public static class TcpManager
{
    #region PInvoke define
    /// <summary>Table class passed to <see cref="GetExtendedTcpTable"/> as <c>tblClass</c>.</summary>
    public const int TCP_TABLE_OWNER_PID_ALL = 5;

    [DllImport("iphlpapi.dll", SetLastError = true)]
    public static extern uint GetExtendedTcpTable(
        IntPtr pTcpTable, ref int dwOutBufLen, bool sort, int ipVersion, int tblClass, int reserved);

    [DllImport("iphlpapi.dll")]
    public static extern int SetTcpEntry(ref MIB_TCPROW pTcpRow);


    /// <summary>Row layout handed to <see cref="SetTcpEntry"/>.</summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct MIB_TCPROW
    {
        public TcpState dwState;
        public int dwLocalAddr;
        public int dwLocalPort;
        public int dwRemoteAddr;
        public int dwRemotePort;
    }

    /// <summary>Row layout read back from <see cref="GetExtendedTcpTable"/>, including the owner PID.</summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct MIB_TCPROW_OWNER_PID
    {
        public TcpState dwState;
        public uint dwLocalAddr;
        public int dwLocalPort;
        public uint dwRemoteAddr;
        public int dwRemotePort;
        public int dwOwningPid;
    }

    /// <summary>Table header: the entry count followed by the first row.</summary>
    [StructLayout(LayoutKind.Sequential)]
    public struct MIB_TCPTABLE_OWNER_PID
    {
        public uint dwNumEntries;
        private MIB_TCPROW_OWNER_PID table;
    }
    #endregion

    /// <summary>
    /// Reads the IPv4 TCP connection table, one row per connection, with owning process IDs.
    /// </summary>
    /// <returns>
    /// The rows of the table, or <c>null</c> when the table could not be retrieved.
    /// </returns>
    public static MIB_TCPROW_OWNER_PID[] GetAllTcpConnections()
    {
        const uint NO_ERROR = 0;
        const uint ERROR_INSUFFICIENT_BUFFER = 122;
        const int IP_v4 = 2;
        // The table can grow between the size query and the fetch; retry a few times
        // instead of giving up on the first ERROR_INSUFFICIENT_BUFFER.
        const int MaxAttempts = 5;

        // First call with no buffer only reports the required size through buffSize.
        int buffSize = 0;
        GetExtendedTcpTable(IntPtr.Zero, ref buffSize, true, IP_v4, TCP_TABLE_OWNER_PID_ALL, 0);

        for (int attempt = 0; attempt < MaxAttempts; attempt++)
        {
            IntPtr buffTable = Marshal.AllocHGlobal(buffSize);
            try
            {
                uint result = GetExtendedTcpTable(buffTable, ref buffSize, true, IP_v4, TCP_TABLE_OWNER_PID_ALL, 0);
                if (result == ERROR_INSUFFICIENT_BUFFER)
                {
                    // buffSize was updated to the new required size; the finally block
                    // frees the undersized buffer before the next attempt.
                    continue;
                }

                if (result != NO_ERROR)
                {
                    return null;
                }

                MIB_TCPTABLE_OWNER_PID tab =
                    (MIB_TCPTABLE_OWNER_PID)Marshal.PtrToStructure(buffTable, typeof(MIB_TCPTABLE_OWNER_PID));
                MIB_TCPROW_OWNER_PID[] tTable = new MIB_TCPROW_OWNER_PID[tab.dwNumEntries];

                // Rows start right after the dwNumEntries field.
                IntPtr rowPtr = IntPtr.Add(buffTable, Marshal.SizeOf(tab.dwNumEntries));
                int rowSize = Marshal.SizeOf(typeof(MIB_TCPROW_OWNER_PID));
                for (int i = 0; i < tTable.Length; i++)
                {
                    tTable[i] =
                        (MIB_TCPROW_OWNER_PID)Marshal.PtrToStructure(rowPtr, typeof(MIB_TCPROW_OWNER_PID));
                    // IntPtr.Add keeps the full pointer width; narrowing the pointer to int
                    // overflows in 64-bit processes when the buffer lives above 2 GB.
                    rowPtr = IntPtr.Add(rowPtr, rowSize);
                }

                return tTable;
            }
            finally
            {
                Marshal.FreeHGlobal(buffTable);
            }
        }

        // Still too small after every retry: report failure the same way as other errors.
        return null;
    }

    /// <summary>
    /// Swaps the two low-order bytes of <paramref name="port"/>, converting a port between the
    /// byte order used in the table rows and the ordinary integer value (the swap is its own inverse).
    /// </summary>
    /// <param name="port">Port value whose low 16 bits are to be byte-swapped.</param>
    /// <returns>The byte-swapped low 16 bits; higher bits are discarded.</returns>
    public static int TranslatePort(int port)
    {
        return ((port & 0xFF) << 8 | (port & 0xFF00) >> 8);
    }

    /// <summary>
    /// Asks the system to delete the TCP connection identified by the two endpoints.
    /// </summary>
    /// <param name="localEndPoint">Local address and port of the connection.</param>
    /// <param name="remoteEndPoint">Remote address and port of the connection.</param>
    /// <returns><c>true</c> when <see cref="SetTcpEntry"/> returned 0, otherwise <c>false</c>.</returns>
    /// <exception cref="ArgumentNullException">Either endpoint is <c>null</c>.</exception>
    public static bool Kill(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint)
    {
        if (localEndPoint == null)
        {
            throw new ArgumentNullException(nameof(localEndPoint));
        }

        if (remoteEndPoint == null)
        {
            throw new ArgumentNullException(nameof(remoteEndPoint));
        }

        MIB_TCPROW row = new MIB_TCPROW();
        row.dwState = TcpState.DeleteTcb;
        row.dwLocalPort = TranslatePort(localEndPoint.Port);
        row.dwRemotePort = TranslatePort(remoteEndPoint.Port);
#pragma warning disable 612, 618
        // IPAddress.Address is obsolete but yields the raw 32-bit IPv4 value the row expects.
        // unchecked: addresses with the high bit set must wrap into int rather than throw
        // when the project is compiled with overflow checking enabled.
        unchecked
        {
            row.dwLocalAddr = (int)localEndPoint.Address.Address;
            row.dwRemoteAddr = (int)remoteEndPoint.Address.Address;
        }
#pragma warning restore 612, 618
        return SetTcpEntry(ref row) == 0;
    }

}