using Microsoft.Extensions.Logging;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
namespace FastGithub.DomainResolve
{
///
/// 域名解析器
///
sealed class DomainResolver : IDomainResolver
{
private const int MAX_IP_COUNT = 3;
private readonly DnsClient dnsClient;
private readonly PersistenceService persistence;
private readonly IPAddressService addressService;
private readonly ILogger logger;
private readonly ConcurrentDictionary dnsEndPointAddress = new();
///
/// 域名解析器
///
///
///
///
///
public DomainResolver(
DnsClient dnsClient,
PersistenceService persistence,
IPAddressService addressService,
ILogger logger)
{
this.dnsClient = dnsClient;
this.persistence = persistence;
this.addressService = addressService;
this.logger = logger;
foreach (var endPoint in persistence.ReadDnsEndPoints())
{
this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty());
}
}
///
/// 解析域名
///
/// 节点
///
///
public async IAsyncEnumerable ResolveAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
{
if (this.dnsEndPointAddress.TryGetValue(endPoint, out var addresses) && addresses.Length > 0)
{
foreach (var address in addresses)
{
yield return address;
}
}
else
{
if (this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty()))
{
await this.persistence.WriteDnsEndPointsAsync(this.dnsEndPointAddress.Keys, cancellationToken);
}
await foreach (var adddress in this.dnsClient.ResolveAsync(endPoint, fastSort: true, cancellationToken))
{
yield return adddress;
}
}
}
///
/// 对所有节点进行测速
///
///
///
public async Task TestSpeedAsync(CancellationToken cancellationToken)
{
foreach (var keyValue in this.dnsEndPointAddress.OrderBy(item => item.Value.Length))
{
var dnsEndPoint = keyValue.Key;
var oldAddresses = keyValue.Value;
var newAddresses = await this.addressService.GetAddressesAsync(dnsEndPoint, oldAddresses, cancellationToken);
this.dnsEndPointAddress[dnsEndPoint] = newAddresses;
var oldSegmentums = oldAddresses.Take(MAX_IP_COUNT);
var newSegmentums = newAddresses.Take(MAX_IP_COUNT);
if (oldSegmentums.SequenceEqual(newSegmentums) == false)
{
var addressArray = string.Join(", ", newSegmentums.Select(item => item.ToString()));
this.logger.LogInformation($"{dnsEndPoint.Host}:{dnsEndPoint.Port}->[{addressArray}]");
}
}
}
}
}