Replace NEKit dependency with reduced subset

This commit is contained in:
relikd
2020-03-24 21:12:58 +01:00
parent 2473e77519
commit cbec3981bb
103 changed files with 24996 additions and 264 deletions

View File

@@ -0,0 +1,18 @@
import Foundation
public enum DNSType: UInt16 {
// swiftlint:disable:next type_name
case invalid = 0, a, ns, md, mf, cname, soa, mb, mg, mr, null, wks, ptr, hinfo, minfo, mx, txt, rp, afsdb, x25, isdn, rt, nsap, nsapptr, sig, key, px, gpos, aaaa, loc, nxt, eid, nimloc, srv, atma, naptr, kx, cert, a6, dname, sink, opt, apl, ds, sshfp, rrsig = 46, nsec, dnskey, tkey = 249, tsig, ixfr, axfr, mailb, maila, any
}
public enum DNSMessageType: UInt8 {
case query, response
}
public enum DNSReturnStatus: UInt8 {
case success = 0, formatError, serverFailure, nameError, notImplemented, refused
}
public enum DNSClass: UInt16 {
case internet = 1
}

View File

@@ -0,0 +1,389 @@
import Foundation
open class DNSMessage {
// var sourceAddress: IPv4Address?
// var sourcePort: Port?
// var destinationAddress: IPv4Address?
// var destinationPort: Port?
open var transactionID: UInt16 = 0
open var messageType: DNSMessageType = .query
open var authoritative: Bool = false
open var truncation: Bool = false
open var recursionDesired: Bool = false
open var recursionAvailable: Bool = false
open var status: DNSReturnStatus = .success
open var queries: [DNSQuery] = []
open var answers: [DNSResource] = []
open var nameservers: [DNSResource] = []
open var addtionals: [DNSResource] = []
var payload: Data!
var bytesLength: Int {
var len = 12 + queries.reduce(0) {
$0 + $1.bytesLength
}
len += answers.reduce(0) {
$0 + $1.bytesLength
}
len += nameservers.reduce(0) {
$0 + $1.bytesLength
}
len += addtionals.reduce(0) {
$0 + $1.bytesLength
}
return len
}
var resolvedIPv4Address: IPAddress? {
for answer in answers {
if let address = answer.ipv4Address {
return address
}
}
return nil
}
var type: DNSType? {
return queries.first?.type
}
init() {}
init?(payload: Data) {
self.payload = payload
let scanner = BinaryDataScanner(data: payload, littleEndian: false)
transactionID = scanner.read16()!
var bytes = scanner.readByte()!
if bytes & 0x80 > 0 {
messageType = .response
} else {
messageType = .query
}
// ignore OP code
authoritative = bytes & 0x04 > 0
truncation = bytes & 0x02 > 0
recursionDesired = bytes & 0x01 > 0
bytes = scanner.readByte()!
recursionAvailable = bytes & 0x80 > 0
if let status = DNSReturnStatus(rawValue: bytes & 0x0F) {
self.status = status
} else {
DDLogError("Received DNS response with unknown status: \(bytes & 0x0F).")
self.status = .serverFailure
}
let queryCount = scanner.read16()!
let answerCount = scanner.read16()!
let nameserverCount = scanner.read16()!
let addtionalCount = scanner.read16()!
for _ in 0..<queryCount {
queries.append(DNSQuery(payload: payload, offset: scanner.position, base: 0)!)
scanner.advance(by: queries.last!.bytesLength)
}
for _ in 0..<answerCount {
answers.append(DNSResource(payload: payload, offset: scanner.position, base: 0)!)
scanner.advance(by: answers.last!.bytesLength)
}
for _ in 0..<nameserverCount {
nameservers.append(DNSResource(payload: payload, offset: scanner.position, base: 0)!)
scanner.advance(by: nameservers.last!.bytesLength)
}
for _ in 0..<addtionalCount {
addtionals.append(DNSResource(payload: payload, offset: scanner.position, base: 0)!)
scanner.advance(by: addtionals.last!.bytesLength)
}
}
func buildMessage() -> Bool {
payload = Data(count: bytesLength)
if transactionID == 0 {
transactionID = UInt16(arc4random_uniform(UInt32(UInt16.max)))
}
setPayloadWithUInt16(transactionID, at: 0, swap: true)
var byte: UInt8 = 0
byte += messageType.rawValue << 7
if authoritative {
byte += 4
}
if truncation {
byte += 2
}
if recursionDesired {
byte += 1
}
setPayloadWithUInt8(byte, at: 2)
byte = 0
if recursionAvailable {
byte += 128
}
byte += status.rawValue
setPayloadWithUInt8(byte, at: 3)
setPayloadWithUInt16(UInt16(queries.count), at: 4, swap: true)
setPayloadWithUInt16(UInt16(answers.count), at: 6, swap: true)
setPayloadWithUInt16(UInt16(nameservers.count), at: 8, swap: true)
setPayloadWithUInt16(UInt16(addtionals.count), at: 10, swap: true)
return writeAllRecordAt(12)
}
// swiftlint:disable variable_name
func setPayloadWithUInt8(_ value: UInt8, at: Int) {
var v = value
withUnsafeBytes(of: &v) {
payload.replaceSubrange(at..<at+1, with: $0)
}
}
func setPayloadWithUInt16(_ value: UInt16, at: Int, swap: Bool = false) {
var v: UInt16
if swap {
v = NSSwapHostShortToBig(value)
} else {
v = value
}
withUnsafeBytes(of: &v) {
payload.replaceSubrange(at..<at+2, with: $0)
}
}
func setPayloadWithUInt32(_ value: UInt32, at: Int, swap: Bool = false) {
var v: UInt32
if swap {
v = NSSwapHostIntToBig(value)
} else {
v = value
}
withUnsafeBytes(of: &v) {
payload.replaceSubrange(at..<at+4, with: $0)
}
}
func setPayloadWithData(_ data: Data, at: Int, length: Int? = nil, from: Int = 0) {
let length = length ?? data.count - from
payload.withUnsafeMutableBytes { ptr in
data.copyBytes(to: ptr.baseAddress!.advanced(by: at).assumingMemoryBound(to: UInt8.self), from: from..<from+length)
}
}
func resetPayloadAt(_ at: Int, length: Int) {
payload.resetBytes(in: at..<at+length)
}
fileprivate func writeAllRecordAt(_ at: Int) -> Bool {
var position = at
for query in queries {
guard writeDNSQuery(query, at: position) else {
return false
}
position += query.bytesLength
}
for resources in [answers, nameservers, addtionals] {
for resource in resources {
guard writeDNSResource(resource, at: position) else {
return false
}
position += resource.bytesLength
}
}
return true
}
fileprivate func writeDNSQuery(_ query: DNSQuery, at: Int) -> Bool {
guard DNSNameConverter.setName(query.name, toData: &payload!, at: at) else {
return false
}
setPayloadWithUInt16(query.type.rawValue, at: at + query.nameBytesLength, swap: true)
setPayloadWithUInt16(query.klass.rawValue, at: at + query.nameBytesLength + 2, swap: true)
return true
}
fileprivate func writeDNSResource(_ resource: DNSResource, at: Int) -> Bool {
guard DNSNameConverter.setName(resource.name, toData: &payload!, at: at) else {
return false
}
setPayloadWithUInt16(resource.type.rawValue, at: at + resource.nameBytesLength, swap: true)
setPayloadWithUInt16(resource.klass.rawValue, at: at + resource.nameBytesLength + 2, swap: true)
setPayloadWithUInt32(resource.TTL, at: at + resource.nameBytesLength + 4, swap: true)
setPayloadWithUInt16(resource.dataLength, at: at + resource.nameBytesLength + 8, swap: true)
setPayloadWithData(resource.data, at: at + resource.nameBytesLength + 10)
return true
}
}
open class DNSQuery {
public let name: String
public let type: DNSType
public let klass: DNSClass
let nameBytesLength: Int
init(name: String, type: DNSType = .a, klass: DNSClass = .internet) {
self.name = name.trimmingCharacters(in: CharacterSet(charactersIn: "."))
self.type = type
self.klass = klass
self.nameBytesLength = name.utf8.count + 2
}
init?(payload: Data, offset: Int, base: Int = 0) {
(self.name, self.nameBytesLength) = DNSNameConverter.getNamefromData(payload, offset: offset, base: base)
let scanner = BinaryDataScanner(data: payload, littleEndian: false)
scanner.skip(to: offset + self.nameBytesLength)
guard let type = DNSType(rawValue: scanner.read16()!) else {
DDLogError("Received DNS packet with unknown type.")
return nil
}
self.type = type
guard let klass = DNSClass(rawValue: scanner.read16()!) else {
DDLogError("Received DNS packet with unknown class.")
return nil
}
self.klass = klass
}
var bytesLength: Int {
return nameBytesLength + 4
}
}
open class DNSResource {
public let name: String
public let type: DNSType
public let klass: DNSClass
public let TTL: UInt32
let dataLength: UInt16
public let data: Data
let nameBytesLength: Int
init(name: String, type: DNSType = .a, klass: DNSClass = .internet, TTL: UInt32 = 300, data: Data) {
self.name = name
self.type = type
self.klass = klass
self.TTL = TTL
dataLength = UInt16(data.count)
self.data = data
self.nameBytesLength = name.utf8.count + 2
}
static func ARecord(_ name: String, TTL: UInt32 = 300, address: IPAddress) -> DNSResource {
return DNSResource(name: name, type: .a, klass: .internet, TTL: TTL, data: address.dataInNetworkOrder)
}
init?(payload: Data, offset: Int, base: Int = 0) {
(self.name, self.nameBytesLength) = DNSNameConverter.getNamefromData(payload, offset: offset, base: base)
let scanner = BinaryDataScanner(data: payload, littleEndian: false)
scanner.skip(to: offset + self.nameBytesLength)
guard let type = DNSType(rawValue: scanner.read16()!) else {
DDLogError("Received DNS packet with unknown type.")
return nil
}
self.type = type
guard let klass = DNSClass(rawValue: scanner.read16()!) else {
DDLogError("Received DNS packet with unknown class.")
return nil
}
self.klass = klass
self.TTL = scanner.read32()!
dataLength = scanner.read16()!
self.data = payload.subdata(in: scanner.position..<scanner.position+Int(dataLength))
}
var bytesLength: Int {
return nameBytesLength + 10 + Int(dataLength)
}
var ipv4Address: IPAddress? {
guard type == .a else {
return nil
}
return IPAddress(fromBytesInNetworkOrder: (data as NSData).bytes)
}
}
class DNSNameConverter {
static func setName(_ name: String, toData data: inout Data, at: Int) -> Bool {
let labels = name.components(separatedBy: CharacterSet(charactersIn: "."))
var position = at
for label in labels {
let len = label.utf8.count
guard len != 0 else {
// invalid domain name
return false
}
data[position] = UInt8(len)
position += 1
data.replaceSubrange(position..<position+len, with: label.data(using: .utf8)!)
position += len
}
data[position] = 0
return true
}
static func getNamefromData(_ data: Data, offset: Int, base: Int = 0) -> (String, Int) {
let scanner = BinaryDataScanner(data: data, littleEndian: false)
scanner.skip(to: offset)
var len: UInt8 = 0
var name = ""
var currentReadBytes = 0
var jumped = false
var nameBytesLength = 0
repeat {
let length = scanner.read16()!
// is this a pointer?
if length & 0xC000 == 0xC000 {
if !jumped {
// save the length position
nameBytesLength = 2 + currentReadBytes
jumped = true
}
scanner.skip(to: Int(length & 0x3FFF) + base)
} else {
scanner.advance(by: -2)
}
len = scanner.readByte()!
currentReadBytes += 1
if len == 0 {
break
}
currentReadBytes += Int(len)
guard let label = String(bytes: scanner.data.subdata(in: scanner.position..<scanner.position+Int(len)), encoding: .utf8) else {
return ("", currentReadBytes)
}
// this is not efficient, but won't take much time, so maybe I'll optimize it later
name = name.appendingFormat(".%@", label)
scanner.advance(by: Int(len))
} while true
if !jumped {
nameBytesLength = currentReadBytes
}
return (name.trimmingCharacters(in: CharacterSet(charactersIn: ".")), nameBytesLength)
}
}

View File

@@ -0,0 +1,37 @@
import Foundation
public protocol DNSResolverProtocol: class {
var delegate: DNSResolverDelegate? { get set }
func resolve(session: DNSSession)
func stop()
}
public protocol DNSResolverDelegate: class {
func didReceive(rawResponse: Data)
}
open class UDPDNSResolver: DNSResolverProtocol, NWUDPSocketDelegate {
let socket: NWUDPSocket
public weak var delegate: DNSResolverDelegate?
public init(address: IPAddress, port: Port) {
socket = NWUDPSocket(host: address.presentation, port: Int(port.value))!
socket.delegate = self
}
public func resolve(session: DNSSession) {
socket.write(data: session.requestMessage.payload)
}
public func stop() {
socket.disconnect()
}
public func didReceive(data: Data, from: NWUDPSocket) {
delegate?.didReceive(rawResponse: data)
}
public func didCancel(socket: NWUDPSocket) {
}
}

View File

@@ -0,0 +1,269 @@
import Foundation
import NetworkExtension
/// A DNS server designed as an `IPStackProtocol` implementation which works with TUN interface.
///
/// This class is thread-safe.
open class DNSServer: DNSResolverDelegate, IPStackProtocol {
/// Current DNS server.
///
/// - warning: There is at most one DNS server running at the same time. If a DNS server is registered to `TUNInterface` then it must also be set here.
public static var currentServer: DNSServer?
/// The address of DNS server.
let serverAddress: IPAddress
/// The port of DNS server
let serverPort: Port
fileprivate let queue: DispatchQueue = QueueFactory.getQueue()
fileprivate var fakeSessions: [IPAddress: DNSSession] = [:]
fileprivate var pendingSessions: [UInt16: DNSSession] = [:]
fileprivate let pool: IPPool?
fileprivate var resolvers: [DNSResolverProtocol] = []
open var outputFunc: (([Data], [NSNumber]) -> Void)!
// Only match A record as of now, all other records should be passed directly.
fileprivate let matchedType = [DNSType.a]
/**
Initailize a DNS server.
- parameter address: The IP address of the server.
- parameter port: The listening port of the server.
- parameter fakeIPPool: The pool of fake IP addresses. Set to nil if no fake IP is needed.
*/
public init(address: IPAddress, port: Port, fakeIPPool: IPPool? = nil) {
serverAddress = address
serverPort = port
pool = fakeIPPool
}
/**
Clean up fake IP.
- parameter address: The fake IP address.
- parameter delay: How long should the fake IP be valid.
*/
fileprivate func cleanUpFakeIP(_ address: IPAddress, after delay: Int) {
queue.asyncAfter(deadline: DispatchTime.now() + Double(Int64(delay) * Int64(NSEC_PER_SEC)) / Double(NSEC_PER_SEC)) {
[weak self] in
_ = self?.fakeSessions.removeValue(forKey: address)
self?.pool?.release(ip: address)
}
}
/**
Clean up pending session.
- parameter session: The pending session.
- parameter delay: How long before the pending session be cleaned up.
*/
fileprivate func cleanUpPendingSession(_ session: DNSSession, after delay: Int) {
queue.asyncAfter(deadline: DispatchTime.now() + Double(Int64(delay) * Int64(NSEC_PER_SEC)) / Double(NSEC_PER_SEC)) {
[weak self] in
_ = self?.pendingSessions.removeValue(forKey: session.requestMessage.transactionID)
}
}
fileprivate func lookup(_ session: DNSSession) {
guard shouldMatch(session) else {
session.matchResult = .real
lookupRemotely(session)
return
}
RuleManager.currentManager.matchDNS(session, type: .domain)
switch session.matchResult! {
case .fake:
guard setUpFakeIP(session) else {
// failed to set up a fake IP, return the result directly
session.matchResult = .real
lookupRemotely(session)
return
}
outputSession(session)
case .real, .unknown:
lookupRemotely(session)
default:
DDLogError("The rule match result should never be .Pass.")
}
}
fileprivate func lookupRemotely(_ session: DNSSession) {
pendingSessions[session.requestMessage.transactionID] = session
cleanUpPendingSession(session, after: Opt.DNSPendingSessionLifeTime)
sendQueryToRemote(session)
}
fileprivate func sendQueryToRemote(_ session: DNSSession) {
for resolver in resolvers {
resolver.resolve(session: session)
}
}
/**
Input IP packet into the DNS server.
- parameter packet: The IP packet.
- parameter version: The version of the IP packet.
- returns: If the packet is taken in by this DNS server.
*/
open func input(packet: Data, version: NSNumber?) -> Bool {
guard IPPacket.peekProtocol(packet) == .udp else {
return false
}
guard IPPacket.peekDestinationAddress(packet) == serverAddress else {
return false
}
guard IPPacket.peekDestinationPort(packet) == serverPort else {
return false
}
guard let ipPacket = IPPacket(packetData: packet) else {
return false
}
guard let session = DNSSession(packet: ipPacket) else {
return false
}
queue.async {
self.lookup(session)
}
return true
}
public func start() {
}
open func stop() {
for resolver in resolvers {
resolver.stop()
}
resolvers = []
// The blocks scheduled with `dispatch_after` are ignored since they are hard to cancel. But there should be no consequence, everything will be released except for a few `IPAddress`es and the `queue` which will be released later.
}
fileprivate func outputSession(_ session: DNSSession) {
guard let result = session.matchResult else {
return
}
let udpParser = UDPProtocolParser()
udpParser.sourcePort = serverPort
// swiftlint:disable:next force_cast
udpParser.destinationPort = (session.requestIPPacket!.protocolParser as! UDPProtocolParser).sourcePort
switch result {
case .real:
udpParser.payload = session.realResponseMessage!.payload
case .fake:
let response = DNSMessage()
response.transactionID = session.requestMessage.transactionID
response.messageType = .response
response.recursionAvailable = true
// since we only support ipv4 as of now, it must be an answer of type A
response.answers.append(DNSResource.ARecord(session.requestMessage.queries[0].name, TTL: UInt32(Opt.DNSFakeIPTTL), address: session.fakeIP!))
session.expireAt = Date().addingTimeInterval(Double(Opt.DNSFakeIPTTL))
guard response.buildMessage() else {
DDLogError("Failed to build DNS response.")
return
}
udpParser.payload = response.payload
default:
return
}
let ipPacket = IPPacket()
ipPacket.sourceAddress = serverAddress
ipPacket.destinationAddress = session.requestIPPacket!.sourceAddress
ipPacket.protocolParser = udpParser
ipPacket.transportProtocol = .udp
ipPacket.buildPacket()
outputFunc([ipPacket.packetData], [NSNumber(value: AF_INET as Int32)])
}
fileprivate func shouldMatch(_ session: DNSSession) -> Bool {
return matchedType.contains(session.requestMessage.type!)
}
func isFakeIP(_ ipAddress: IPAddress) -> Bool {
return pool?.contains(ip: ipAddress) ?? false
}
func lookupFakeIP(_ address: IPAddress) -> DNSSession? {
var session: DNSSession?
QueueFactory.executeOnQueueSynchronizedly {
session = self.fakeSessions[address]
}
return session
}
/**
Add new DNS resolver to DNS server.
- parameter resolver: The resolver to add.
*/
open func registerResolver(_ resolver: DNSResolverProtocol) {
resolver.delegate = self
resolvers.append(resolver)
}
fileprivate func setUpFakeIP(_ session: DNSSession) -> Bool {
guard let fakeIP = pool?.fetchIP() else {
DDLogVerbose("Failed to get a fake IP.")
return false
}
session.fakeIP = fakeIP
fakeSessions[fakeIP] = session
session.expireAt = Date().addingTimeInterval(TimeInterval(Opt.DNSFakeIPTTL))
// keep the fake session for 2 TTL
cleanUpFakeIP(fakeIP, after: Opt.DNSFakeIPTTL * 2)
return true
}
open func didReceive(rawResponse: Data) {
guard let message = DNSMessage(payload: rawResponse) else {
DDLogError("Failed to parse response from remote DNS server.")
return
}
queue.async {
guard let session = self.pendingSessions.removeValue(forKey: message.transactionID) else {
// this should not be a problem if there are multiple DNS servers or the DNS server is hijacked.
DDLogVerbose("Do not find the corresponding DNS session for the response.")
return
}
session.realResponseMessage = message
session.realIP = message.resolvedIPv4Address
if session.matchResult != .fake && session.matchResult != .real {
RuleManager.currentManager.matchDNS(session, type: .ip)
}
switch session.matchResult! {
case .fake:
if !self.setUpFakeIP(session) {
// return real response
session.matchResult = .real
}
self.outputSession(session)
case .real:
self.outputSession(session)
default:
DDLogError("The rule match result should never be .Pass or .Unknown in IP mode.")
}
}
}
}

View File

@@ -0,0 +1,49 @@
import Foundation
open class DNSSession {
public let requestMessage: DNSMessage
var requestIPPacket: IPPacket?
open var realIP: IPAddress?
open var fakeIP: IPAddress?
open var realResponseMessage: DNSMessage?
var realResponseIPPacket: IPPacket?
open var matchedRule: Rule?
open var matchResult: DNSSessionMatchResult?
var indexToMatch = 0
var expireAt: Date?
// lazy var countryCode: String? = {
// [unowned self] in
// guard self.realIP != nil else {
// return nil
// }
// return Utils.GeoIPLookup.Lookup(self.realIP!.presentation)
// }()
init?(message: DNSMessage) {
guard message.messageType == .query else {
DDLogError("DNSSession can only be initailized by a DNS query.")
return nil
}
guard message.queries.count == 1 else {
DDLogError("Expecting the DNS query has exact one query entry.")
return nil
}
requestMessage = message
}
convenience init?(packet: IPPacket) {
guard let message = DNSMessage(payload: packet.protocolParser.payload) else {
return nil
}
self.init(message: message)
requestIPPacket = packet
}
}
extension DNSSession: CustomStringConvertible {
public var description: String {
return "<\(type(of: self)) domain: \(self.requestMessage.queries.first!.name) realIP: \(String(describing: realIP)) fakeIP: \(String(describing: fakeIP))>"
}
}