FastGithub/FastGithub.HttpServer/TcpMiddlewares/TunnelMiddleware.cs
xingyuan55 4d9d97f871 start
2022-11-16 08:01:03 +08:00

133 lines
5.0 KiB
C#

using FastGithub.Configuration;
using FastGithub.DomainResolve;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
namespace FastGithub.HttpServer.TcpMiddlewares
{
    /// <summary>
    /// Tunnel middleware: for tunnel-proxy connections that do not carry TLS, opens a TCP
    /// connection to the requested proxy host and relays bytes in both directions between
    /// the client transport and that upstream connection.
    /// </summary>
    sealed class TunnelMiddleware
    {
        private readonly FastGithubConfig fastGithubConfig;
        private readonly IDomainResolver domainResolver;

        /// <summary>
        /// Upper bound for a single upstream connect attempt (applied per endpoint).
        /// </summary>
        private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);

        /// <summary>
        /// Creates the tunnel middleware.
        /// </summary>
        /// <param name="fastGithubConfig">Configuration whose <c>IsMatch</c> decides whether a host name is resolved through <paramref name="domainResolver"/>.</param>
        /// <param name="domainResolver">Resolver used for host names matched by <paramref name="fastGithubConfig"/>.</param>
        public TunnelMiddleware(
            FastGithubConfig fastGithubConfig,
            IDomainResolver domainResolver)
        {
            this.fastGithubConfig = fastGithubConfig;
            this.domainResolver = domainResolver;
        }

        /// <summary>
        /// Executes the middleware.
        /// Connections without an <see cref="IHttpProxyFeature"/>, whose proxy protocol is not
        /// <c>TunnelProxy</c>, or that already have an <see cref="ITlsConnectionFeature"/>
        /// are handed to <paramref name="next"/>. Otherwise the connection's transport is
        /// spliced to a new upstream TCP connection until either direction ends.
        /// </summary>
        /// <param name="next">The next connection delegate in the pipeline.</param>
        /// <param name="context">The current connection.</param>
        /// <returns>A task that completes when the connection has been handled.</returns>
        public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext context)
        {
            var proxyFeature = context.Features.Get<IHttpProxyFeature>();
            if (proxyFeature == null ||                                    // not a proxied connection
                proxyFeature.ProxyProtocol != ProxyProtocol.TunnelProxy || // not a tunnel proxy
                context.Features.Get<ITlsConnectionFeature>() != null)    // TLS present: HTTPS that came through the tunnel
            {
                await next(context);
                return;
            }

            var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport;
            if (transport == null)
            {
                return;
            }

            var cancellationToken = context.ConnectionClosed;
            using var connection = await this.CreateConnectionAsync(proxyFeature.ProxyHost, cancellationToken);

            // Separate token for the relay so one direction finishing can stop the other,
            // independently of the client connection closing.
            using var tunnelTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

            var upstreamToClient = connection.CopyToAsync(transport.Output, tunnelTokenSource.Token);
            var clientToUpstream = transport.Input.CopyToAsync(connection, tunnelTokenSource.Token);
            await Task.WhenAny(upstreamToClient, clientToUpstream);

            // One direction has ended (EOF, error or cancellation). Stop the other one instead of
            // abandoning it: cancel its pending pipe reads/flushes and dispose the upstream stream
            // so pending socket I/O is aborted, then await both so no task exception goes unobserved.
            tunnelTokenSource.Cancel();
            connection.Dispose();
            try
            {
                await Task.WhenAll(upstreamToClient, clientToUpstream);
            }
            catch (Exception)
            {
                // Relay errors after the tunnel is established are intentionally not propagated;
                // cancellation, disposed-stream and connection-reset failures are expected on teardown.
            }
        }

        /// <summary>
        /// Connects to the first reachable upstream endpoint for <paramref name="host"/>.
        /// Endpoints are tried in order; each attempt is bounded by <see cref="connectTimeout"/>.
        /// </summary>
        /// <param name="host">The target host (and optional port) requested by the client.</param>
        /// <param name="cancellationToken">Cancels the whole operation.</param>
        /// <returns>A <see cref="NetworkStream"/> that owns the connected socket.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled.</exception>
        /// <exception cref="AggregateException">Every endpoint failed; inner exceptions hold the per-endpoint failures.</exception>
        private async Task<Stream> CreateConnectionAsync(HostString host, CancellationToken cancellationToken)
        {
            var innerExceptions = new List<Exception>();
            await foreach (var endPoint in this.GetUpstreamEndPointsAsync(host, cancellationToken))
            {
                var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
                try
                {
                    using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
                    using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
                    await socket.ConnectAsync(endPoint, linkedTokenSource.Token);
                    return new NetworkStream(socket, ownsSocket: true);
                }
                catch (Exception ex)
                {
                    socket.Dispose();
                    // A caller cancellation aborts the whole loop; a per-attempt timeout or
                    // connect failure just moves on to the next endpoint.
                    cancellationToken.ThrowIfCancellationRequested();
                    innerExceptions.Add(ex);
                }
            }
            throw new AggregateException($"无法连接到{host}", innerExceptions);
        }

        /// <summary>
        /// Yields the upstream endpoints to try for <paramref name="host"/>:
        /// an IP literal yields a single <see cref="IPEndPoint"/>; a host name not matched by the
        /// configuration yields a <see cref="DnsEndPoint"/> (left for the socket to resolve); a
        /// matched host name yields one <see cref="IPEndPoint"/> per address from the domain resolver.
        /// The port defaults to 443 when <paramref name="host"/> has none.
        /// </summary>
        /// <param name="host">The target host (and optional port).</param>
        /// <param name="cancellationToken">Cancels the enumeration.</param>
        /// <returns>The candidate endpoints, in the order they should be tried.</returns>
        private async IAsyncEnumerable<EndPoint> GetUpstreamEndPointsAsync(HostString host, [EnumeratorCancellation] CancellationToken cancellationToken)
        {
            const int HttpsPort = 443;
            var targetHost = host.Host;
            var targetPort = host.Port ?? HttpsPort;

            if (IPAddress.TryParse(targetHost, out var address))
            {
                yield return new IPEndPoint(address, targetPort);
            }
            else if (!this.fastGithubConfig.IsMatch(targetHost))
            {
                yield return new DnsEndPoint(targetHost, targetPort);
            }
            else
            {
                var dnsEndPoint = new DnsEndPoint(targetHost, targetPort);
                await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint, cancellationToken))
                {
                    yield return new IPEndPoint(item, targetPort);
                }
            }
        }
    }
}