using FastGithub.DomainResolve;
using Microsoft.AspNetCore.Connections;
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

namespace FastGithub.HttpServer.TcpMiddlewares
{
    /// <summary>
    /// TCP protocol proxy handler: forwards an accepted client connection to a fixed
    /// upstream <see cref="DnsEndPoint"/>, copying bytes in both directions.
    /// </summary>
    abstract class TcpReverseProxyHandler : ConnectionHandler
    {
        private readonly IDomainResolver domainResolver;
        private readonly DnsEndPoint endPoint;

        // Upper bound for a single connect attempt to one resolved address.
        private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);

        /// <summary>
        /// TCP protocol proxy handler.
        /// </summary>
        /// <param name="domainResolver">Resolver used to obtain the addresses of <paramref name="endPoint"/>.</param>
        /// <param name="endPoint">Upstream host and port that client connections are forwarded to.</param>
        protected TcpReverseProxyHandler(IDomainResolver domainResolver, DnsEndPoint endPoint)
        {
            this.domainResolver = domainResolver;
            this.endPoint = endPoint;
        }

        /// <summary>
        /// Called after a client TCP connection is accepted. It opens the upstream connection and
        /// copies data in both directions until either direction ends.
        /// </summary>
        /// <param name="context">The accepted client connection.</param>
        /// <returns>A task that completes when both copy directions have stopped.</returns>
        /// <exception cref="AggregateException">No resolved upstream address could be connected.</exception>
        public override async Task OnConnectedAsync(ConnectionContext context)
        {
            var cancellationToken = context.ConnectionClosed;
            using var connection = await this.CreateConnectionAsync(cancellationToken);

            // A private token lets us stop the surviving copy loop once the other direction has ended.
            using var copyTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
            var upstreamToClient = connection.CopyToAsync(context.Transport.Output, copyTokenSource.Token);
            var clientToUpstream = context.Transport.Input.CopyToAsync(connection, copyTokenSource.Token);

            await Task.WhenAny(upstreamToClient, clientToUpstream);
            copyTokenSource.Cancel();

            try
            {
                // Wait for both loops so neither touches `connection` after it is disposed,
                // and so any exception they produced is observed.
                await Task.WhenAll(upstreamToClient, clientToUpstream);
            }
            catch (Exception)
            {
                // Deliberately not propagated (same as the previous WhenAny-only version).
                // Cancelling the surviving loop, or a peer closing its side, is the normal end
                // of a proxied session.
            }
        }

        /// <summary>
        /// Resolves the upstream endpoint and tries each resolved address in order.
        /// Returns a stream over the first successful connection.
        /// </summary>
        /// <param name="cancellationToken">Aborts resolution and all connection attempts.</param>
        /// <returns>A <see cref="NetworkStream"/> that owns the connected socket and closes it on dispose.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
        /// <exception cref="AggregateException">Every address failed, or none was resolved. The inner exceptions hold the individual failures.</exception>
        private async Task<Stream> CreateConnectionAsync(CancellationToken cancellationToken)
        {
            var innerExceptions = new List<Exception>();
            await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint, cancellationToken))
            {
                var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                try
                {
                    // Each address gets its own timeout, still bound to the caller's token.
                    using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
                    using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
                    await socket.ConnectAsync(address, this.endPoint.Port, linkedTokenSource.Token);

                    // ownsSocket: true so disposing the stream also closes the socket (otherwise it leaks).
                    return new NetworkStream(socket, ownsSocket: true);
                }
                catch (Exception ex)
                {
                    socket.Dispose();

                    // Caller cancellation aborts everything.
                    // A per-address failure or timeout moves on to the next address.
                    cancellationToken.ThrowIfCancellationRequested();
                    innerExceptions.Add(ex);
                }
            }

            throw new AggregateException($"无法连接到{endPoint.Host}:{endPoint.Port}", innerExceptions);
        }
    }
}