using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Options; using System; using System.Collections.Generic; using System.Diagnostics; using System.Linq; using System.Net; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; namespace FastGithub.DomainResolve { /// /// IP服务 /// 域名IP关系缓存10分钟 /// IPEndPoint时延缓存5分钟 /// IPEndPoint连接超时5秒 /// sealed class IPAddressService { private record DomainAddress(string Domain, IPAddress Address); private readonly TimeSpan domainAddressExpiration = TimeSpan.FromMinutes(10d); private readonly IMemoryCache domainAddressCache = new MemoryCache(Options.Create(new MemoryCacheOptions())); private record AddressElapsed(IPAddress Address, TimeSpan Elapsed); private readonly TimeSpan problemElapsedExpiration = TimeSpan.FromMinutes(1d); private readonly TimeSpan normalElapsedExpiration = TimeSpan.FromMinutes(5d); private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(5d); private readonly IMemoryCache addressElapsedCache = new MemoryCache(Options.Create(new MemoryCacheOptions())); private readonly DnsClient dnsClient; /// /// IP服务 /// /// public IPAddressService(DnsClient dnsClient) { this.dnsClient = dnsClient; } /// /// 并行获取可连接的IP /// /// /// /// /// public async Task GetAddressesAsync(DnsEndPoint dnsEndPoint, IEnumerable oldAddresses, CancellationToken cancellationToken) { var ipEndPoints = new HashSet(); // 历史未过期的IP节点 foreach (var address in oldAddresses) { var domainAddress = new DomainAddress(dnsEndPoint.Host, address); if (this.domainAddressCache.TryGetValue(domainAddress, out _)) { ipEndPoints.Add(new IPEndPoint(address, dnsEndPoint.Port)); } } // 新解析出的IP节点 await foreach (var address in this.dnsClient.ResolveAsync(dnsEndPoint, fastSort: false, cancellationToken)) { ipEndPoints.Add(new IPEndPoint(address, dnsEndPoint.Port)); var domainAddress = new DomainAddress(dnsEndPoint.Host, address); this.domainAddressCache.Set(domainAddress, default(object), this.domainAddressExpiration); } if (ipEndPoints.Count == 0) { return Array.Empty(); } var addressElapsedTasks = ipEndPoints.Select(item => this.GetAddressElapsedAsync(item, cancellationToken)); var addressElapseds = await Task.WhenAll(addressElapsedTasks); return addressElapseds .Where(item => item.Elapsed < TimeSpan.MaxValue) .OrderBy(item => item.Elapsed) .Select(item => item.Address) .ToArray(); } /// /// 获取IP节点的时延 /// /// /// /// private async Task GetAddressElapsedAsync(IPEndPoint endPoint, CancellationToken cancellationToken) { if (this.addressElapsedCache.TryGetValue(endPoint, out var addressElapsed)) { return addressElapsed; } var stopWatch = Stopwatch.StartNew(); try { using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout); using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token); using var socket = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp); await socket.ConnectAsync(endPoint, linkedTokenSource.Token); addressElapsed = new AddressElapsed(endPoint.Address, stopWatch.Elapsed); return this.addressElapsedCache.Set(endPoint, addressElapsed, this.normalElapsedExpiration); } catch (Exception ex) { cancellationToken.ThrowIfCancellationRequested(); addressElapsed = new AddressElapsed(endPoint.Address, TimeSpan.MaxValue); var expiration = IsLocalNetworkProblem(ex) ? this.problemElapsedExpiration : this.normalElapsedExpiration; return this.addressElapsedCache.Set(endPoint, addressElapsed, expiration); } finally { stopWatch.Stop(); } } /// /// 是否为本机网络问题 /// /// /// private static bool IsLocalNetworkProblem(Exception ex) { if (ex is not SocketException socketException) { return false; } var code = socketException.SocketErrorCode; return code == SocketError.NetworkDown || code == SocketError.NetworkUnreachable; } } }