Replace NEKit dependency with reduced subset

This commit is contained in:
relikd
2020-03-24 21:12:58 +01:00
parent 2473e77519
commit cbec3981bb
103 changed files with 24996 additions and 264 deletions

24
GlassVPN/DDLog.swift Normal file
View File

@@ -0,0 +1,24 @@
import Foundation
// MARK: Third party dependencies
//
// Removed unnecessary parts of NEKit to keep the dependency chain small.
// Omitting embeded frameworks; which aren't allowed in NetworkExtensions?
//
// 0.15.0 https://github.com/zhuhaow/NEKit/commit/f09ba8aef1e70881edf0578d23c04d88cc706f52
// 0.3.0 https://github.com/zhuhaow/Resolver/commit/5d08fd52822d1f9217019ae8867e78daa48f667c
// 7.6.4 https://github.com/robbiehanson/CocoaAsyncSocket/commit/0e00c967a010fc43ce528bd633d032f17158d393
// MARK: DDLog
#if DEBUG
@inlinable public func DDLogVerbose(_ message: String) { NSLog("[VPN.VERBOSE] " + message) }
@inlinable public func DDLogDebug(_ message: String) { NSLog("[VPN.DEBUG] " + message) }
#else
@inlinable public func DDLogVerbose(_ _: String) {}
@inlinable public func DDLogDebug(_ _: String) {}
#endif
@inlinable public func DDLogInfo(_ message: String) { NSLog("[VPN.INFO] " + message) }
@inlinable public func DDLogWarn(_ message: String) { NSLog("[VPN.WARN] " + message) }
@inlinable public func DDLogError(_ message: String) { NSLog("[VPN.ERROR] " + message) }

View File

@@ -1,5 +1,4 @@
import NetworkExtension
import NEKit
fileprivate var db: SQLiteDatabase?
fileprivate var domainFilters: [String : FilterOptions] = [:]
@@ -16,13 +15,13 @@ class LDObserverFactory: ObserverFactory {
override func signal(_ event: ProxySocketEvent) {
switch event {
case .receivedRequest(let session, let socket):
ZLog("DNS: \(session.host)")
DDLogDebug("DNS: \(session.host)")
let match = domainFilters.first { session.host == $0.key || session.host.hasSuffix("." + $0.key) }
let block = match?.value.contains(.blocked) ?? false
let ignore = match?.value.contains(.ignored) ?? false
if !ignore { try? db?.insertDNSQuery(session.host, blocked: block) }
else { ZLog("ignored") }
if block { ZLog("blocked"); socket.forceDisconnect() }
else { DDLogDebug("ignored") }
if block { DDLogDebug("blocked"); socket.forceDisconnect() }
default:
break
}
@@ -44,10 +43,10 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
}
override func startTunnel(options: [String : NSObject]?, completionHandler: @escaping (Error?) -> Void) {
ZLog("startTunnel")
DDLogVerbose("startTunnel")
do {
db = try SQLiteDatabase.open()
try db!.createTable(table: DNSQuery.self)
db!.initScheme()
} catch {
completionHandler(error)
return
@@ -79,11 +78,11 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
self.setTunnelNetworkSettings(settings) { error in
guard error == nil else {
ZLog("setTunnelNetworkSettings error: \(String(describing: error))")
DDLogError("setTunnelNetworkSettings error: \(String(describing: error))")
completionHandler(error)
return
}
ZLog("setTunnelNetworkSettings success \(self.packetFlow)")
DDLogVerbose("setTunnelNetworkSettings success \(self.packetFlow)")
completionHandler(nil)
self.proxyServer = GCDHTTPProxyServer(address: IPAddress(fromString: self.proxyServerAddress), port: Port(port: self.proxyServerPort))
@@ -92,7 +91,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
completionHandler(nil)
}
catch let proxyError {
ZLog("Error starting proxy server \(proxyError)")
DDLogError("Error starting proxy server \(proxyError)")
completionHandler(proxyError)
}
}
@@ -100,24 +99,19 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
ZLog("stopTunnel")
DDLogVerbose("stopTunnel with reason: \(reason)")
db = nil
DNSServer.currentServer = nil
RawSocketFactory.TunnelProvider = nil
ObserverFactory.currentFactory = nil
proxyServer.stop()
proxyServer = nil
ZLog("error on stopping: \(reason)")
completionHandler()
exit(EXIT_SUCCESS)
}
override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) {
ZLog("handleAppMessage")
DDLogVerbose("handleAppMessage")
reloadDomainFilter()
}
}
fileprivate func ZLog(_ message: String) {
NSLog("TUN: \(message)")
}

View File

@@ -1,10 +0,0 @@
$(SRCROOT)/Carthage/Build/iOS/CocoaAsyncSocket.framework
$(SRCROOT)/Carthage/Build/iOS/CocoaLumberjack.framework
$(SRCROOT)/Carthage/Build/iOS/CocoaLumberjackSwift.framework
$(SRCROOT)/Carthage/Build/iOS/lwip.framework
$(SRCROOT)/Carthage/Build/iOS/MMDB.framework
$(SRCROOT)/Carthage/Build/iOS/NEKit.framework
$(SRCROOT)/Carthage/Build/iOS/Resolver.framework
$(SRCROOT)/Carthage/Build/iOS/Sodium.framework
$(SRCROOT)/Carthage/Build/iOS/tun2socks.framework
$(SRCROOT)/Carthage/Build/iOS/Yaml.framework

View File

@@ -1,10 +0,0 @@
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/CocoaAsyncSocket.framework
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/CocoaLumberjack.framework
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/CocoaLumberjackSwift.framework
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/lwip.framework
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/MMDB.framework
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/NEKit.framework
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/Resolver.framework
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/Sodium.framework
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/tun2socks.framework
$(BUILT_PRODUCTS_DIR)/$(FRAMEWORKS_FOLDER_PATH)/Yaml.framework

View File

@@ -0,0 +1,18 @@
//
// CocoaAsyncSocket.h
// CocoaAsyncSocket
//
// Created by Derek Clarkson on 10/08/2015.
// CocoaAsyncSocket project is in the public domain.
//
@import Foundation;
//! Project version number for CocoaAsyncSocket.
FOUNDATION_EXPORT double cocoaAsyncSocketVersionNumber;
//! Project version string for CocoaAsyncSocket.
FOUNDATION_EXPORT const unsigned char cocoaAsyncSocketVersionString[];
#import "GCDAsyncSocket.h"
#import "GCDAsyncUdpSocket.h"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,40 @@
import Foundation
public enum AdapterSocketEvent: EventType {
public var description: String {
switch self {
case let .socketOpened(socket, withSession: session):
return "Adatper socket \(socket) starts to connect to remote with session \(session)."
case .disconnectCalled(let socket):
return "Disconnect is just called on adapter socket \(socket)."
case .forceDisconnectCalled(let socket):
return "Force disconnect is just called on adapter socket \(socket)."
case .disconnected(let socket):
return "Adapter socket \(socket) disconnected."
case let .readData(data, on: socket):
return "Received \(data.count) bytes data on adatper socket \(socket)."
case let .wroteData(data, on: socket):
if let data = data {
return "Sent \(data.count) bytes data on adapter socket \(socket)."
} else {
return "Sent data on adapter socket \(socket)."
}
case let .connected(socket):
return "Adapter socket \(socket) connected to remote."
case .readyForForward(let socket):
return "Adatper socket \(socket) is ready to forward data."
case let .errorOccured(error, on: socket):
return "Adapter socket \(socket) encountered an error \(error)."
}
}
case socketOpened(AdapterSocket, withSession: ConnectSession),
disconnectCalled(AdapterSocket),
forceDisconnectCalled(AdapterSocket),
disconnected(AdapterSocket),
readData(Data, on: AdapterSocket),
wroteData(Data?, on: AdapterSocket),
connected(AdapterSocket),
readyForForward(AdapterSocket),
errorOccured(Error, on: AdapterSocket)
}

View File

@@ -0,0 +1,3 @@
import Foundation
public protocol EventType: CustomStringConvertible {}

View File

@@ -0,0 +1,18 @@
import Foundation
public enum ProxyServerEvent: EventType {
public var description: String {
switch self {
case let .newSocketAccepted(socket, onServer: server):
return "Proxy server \(server) just accepted a new socket \(socket)."
case let .tunnelClosed(tunnel, onServer: server):
return "A tunnel \(tunnel) on proxy server \(server) just closed."
case .started(let server):
return "Proxy server \(server) started."
case .stopped(let server):
return "Proxy server \(server) stopped."
}
}
case newSocketAccepted(ProxySocket, onServer: ProxyServer), tunnelClosed(Tunnel, onServer: ProxyServer), started(ProxyServer), stopped(ProxyServer)
}

View File

@@ -0,0 +1,43 @@
import Foundation
public enum ProxySocketEvent: EventType {
public var description: String {
switch self {
case .socketOpened(let socket):
return "Start processing data from proxy socket \(socket)."
case .disconnectCalled(let socket):
return "Disconnect is just called on proxy socket \(socket)."
case .forceDisconnectCalled(let socket):
return "Force disconnect is just called on proxy socket \(socket)."
case .disconnected(let socket):
return "Proxy socket \(socket) disconnected."
case let .receivedRequest(session, on: socket):
return "Proxy socket \(socket) received request \(session)."
case let .readData(data, on: socket):
return "Received \(data.count) bytes data on proxy socket \(socket)."
case let .wroteData(data, on: socket):
if let data = data {
return "Sent \(data.count) bytes data on proxy socket \(socket)."
} else {
return "Sent data on proxy socket \(socket)."
}
case let .askedToResponseTo(adapter, on: socket):
return "Proxy socket \(socket) is asked to respond to adapter \(adapter)."
case .readyForForward(let socket):
return "Proxy socket \(socket) is ready to forward data."
case let .errorOccured(error, on: socket):
return "Proxy socket \(socket) encountered an error \(error)."
}
}
case socketOpened(ProxySocket),
disconnectCalled(ProxySocket),
forceDisconnectCalled(ProxySocket),
disconnected(ProxySocket),
receivedRequest(ConnectSession, on: ProxySocket),
readData(Data, on: ProxySocket),
wroteData(Data?, on: ProxySocket),
askedToResponseTo(AdapterSocket, on: ProxySocket),
readyForForward(ProxySocket),
errorOccured(Error, on: ProxySocket)
}

View File

@@ -0,0 +1,16 @@
import Foundation
public enum RuleMatchEvent: EventType {
public var description: String {
switch self {
case let .ruleMatched(session, rule: rule):
return "Rule \(rule) matched session \(session)."
case let .ruleDidNotMatch(session, rule: rule):
return "Rule \(rule) did not match session \(session)."
case let .dnsRuleMatched(session, rule: rule, type: type, result: result):
return "Rule \(rule) matched DNS session \(session) of type \(type), the result is \(result)."
}
}
case ruleMatched(ConnectSession, rule: Rule), ruleDidNotMatch(ConnectSession, rule: Rule), dnsRuleMatched(DNSSession, rule: Rule, type: DNSSessionMatchType, result: DNSSessionMatchResult)
}

View File

@@ -0,0 +1,57 @@
import Foundation
public enum TunnelEvent: EventType {
public var description: String {
switch self {
case .opened(let tunnel):
return "Tunnel \(tunnel) starts processing data."
case .closeCalled(let tunnel):
return "Close is called on tunnel \(tunnel)."
case .forceCloseCalled(let tunnel):
return "Force close is called on tunnel \(tunnel)."
case let .receivedRequest(request, from: socket, on: tunnel):
return "Tunnel \(tunnel) received request \(request) from proxy socket \(socket)."
case let .receivedReadySignal(socket, currentReady: signal, on: tunnel):
if signal == 1 {
return "Tunnel \(tunnel) received ready-for-forward signal from socket \(socket)."
} else {
return "Tunnel \(tunnel) received ready-for-forward signal from socket \(socket). Start forwarding data."
}
case let .proxySocketReadData(data, from: socket, on: tunnel):
return "Tunnel \(tunnel) received \(data.count) bytes from proxy socket \(socket)."
case let .proxySocketWroteData(data, by: socket, on: tunnel):
if let data = data {
return "Proxy socket \(socket) sent \(data.count) bytes data from Tunnel \(tunnel)."
} else {
return "Proxy socket \(socket) sent data from Tunnel \(tunnel)."
}
case let .adapterSocketReadData(data, from: socket, on: tunnel):
return "Tunnel \(tunnel) received \(data.count) bytes from adapter socket \(socket)."
case let .adapterSocketWroteData(data, by: socket, on: tunnel):
if let data = data {
return "Adatper socket \(socket) sent \(data.count) bytes data from Tunnel \(tunnel)."
} else {
return "Adapter socket \(socket) sent data from Tunnel \(tunnel)."
}
case let .connectedToRemote(socket, on: tunnel):
return "Adapter socket \(socket) connected to remote successfully on tunnel \(tunnel)."
case let .updatingAdapterSocket(from: old, to: new, on: tunnel):
return "Updating adapter socket of tunnel \(tunnel) from \(old) to \(new)."
case .closed(let tunnel):
return "Tunnel \(tunnel) closed."
}
}
case opened(Tunnel),
closeCalled(Tunnel),
forceCloseCalled(Tunnel),
receivedRequest(ConnectSession, from: ProxySocket, on: Tunnel),
receivedReadySignal(SocketProtocol, currentReady: Int, on: Tunnel),
proxySocketReadData(Data, from: ProxySocket, on: Tunnel),
proxySocketWroteData(Data?, by: ProxySocket, on: Tunnel),
adapterSocketReadData(Data, from: AdapterSocket, on: Tunnel),
adapterSocketWroteData(Data?, by: AdapterSocket, on: Tunnel),
connectedToRemote(AdapterSocket, on: Tunnel),
updatingAdapterSocket(from: AdapterSocket, to: AdapterSocket, on: Tunnel),
closed(Tunnel)
}

View File

@@ -0,0 +1,6 @@
import Foundation
open class Observer<T: EventType> {
public init() {}
open func signal(_ event: T) {}
}

View File

@@ -0,0 +1,27 @@
import Foundation
open class ObserverFactory {
public static var currentFactory: ObserverFactory?
public init() {}
open func getObserverForTunnel(_ tunnel: Tunnel) -> Observer<TunnelEvent>? {
return nil
}
open func getObserverForAdapterSocket(_ socket: AdapterSocket) -> Observer<AdapterSocketEvent>? {
return nil
}
open func getObserverForProxySocket(_ socket: ProxySocket) -> Observer<ProxySocketEvent>? {
return nil
}
open func getObserverForProxyServer(_ server: ProxyServer) -> Observer<ProxyServerEvent>? {
return nil
}
open func getObserverForRuleManager(_ manager: RuleManager) -> Observer<RuleMatchEvent>? {
return nil
}
}

View File

@@ -0,0 +1,12 @@
import Foundation
struct GlobalIntializer {
private static let _initialized: Bool = {
Resolver.queue = QueueFactory.getQueue()
return true
}()
static func initalize() {
_ = _initialized
}
}

View File

@@ -0,0 +1,18 @@
import Foundation
public enum DNSType: UInt16 {
// swiftlint:disable:next type_name
case invalid = 0, a, ns, md, mf, cname, soa, mb, mg, mr, null, wks, ptr, hinfo, minfo, mx, txt, rp, afsdb, x25, isdn, rt, nsap, nsapptr, sig, key, px, gpos, aaaa, loc, nxt, eid, nimloc, srv, atma, naptr, kx, cert, a6, dname, sink, opt, apl, ds, sshfp, rrsig = 46, nsec, dnskey, tkey = 249, tsig, ixfr, axfr, mailb, maila, any
}
public enum DNSMessageType: UInt8 {
case query, response
}
public enum DNSReturnStatus: UInt8 {
case success = 0, formatError, serverFailure, nameError, notImplemented, refused
}
public enum DNSClass: UInt16 {
case internet = 1
}

View File

@@ -0,0 +1,389 @@
import Foundation
open class DNSMessage {
// var sourceAddress: IPv4Address?
// var sourcePort: Port?
// var destinationAddress: IPv4Address?
// var destinationPort: Port?
open var transactionID: UInt16 = 0
open var messageType: DNSMessageType = .query
open var authoritative: Bool = false
open var truncation: Bool = false
open var recursionDesired: Bool = false
open var recursionAvailable: Bool = false
open var status: DNSReturnStatus = .success
open var queries: [DNSQuery] = []
open var answers: [DNSResource] = []
open var nameservers: [DNSResource] = []
open var addtionals: [DNSResource] = []
var payload: Data!
var bytesLength: Int {
var len = 12 + queries.reduce(0) {
$0 + $1.bytesLength
}
len += answers.reduce(0) {
$0 + $1.bytesLength
}
len += nameservers.reduce(0) {
$0 + $1.bytesLength
}
len += addtionals.reduce(0) {
$0 + $1.bytesLength
}
return len
}
var resolvedIPv4Address: IPAddress? {
for answer in answers {
if let address = answer.ipv4Address {
return address
}
}
return nil
}
var type: DNSType? {
return queries.first?.type
}
init() {}
init?(payload: Data) {
self.payload = payload
let scanner = BinaryDataScanner(data: payload, littleEndian: false)
transactionID = scanner.read16()!
var bytes = scanner.readByte()!
if bytes & 0x80 > 0 {
messageType = .response
} else {
messageType = .query
}
// ignore OP code
authoritative = bytes & 0x04 > 0
truncation = bytes & 0x02 > 0
recursionDesired = bytes & 0x01 > 0
bytes = scanner.readByte()!
recursionAvailable = bytes & 0x80 > 0
if let status = DNSReturnStatus(rawValue: bytes & 0x0F) {
self.status = status
} else {
DDLogError("Received DNS response with unknown status: \(bytes & 0x0F).")
self.status = .serverFailure
}
let queryCount = scanner.read16()!
let answerCount = scanner.read16()!
let nameserverCount = scanner.read16()!
let addtionalCount = scanner.read16()!
for _ in 0..<queryCount {
queries.append(DNSQuery(payload: payload, offset: scanner.position, base: 0)!)
scanner.advance(by: queries.last!.bytesLength)
}
for _ in 0..<answerCount {
answers.append(DNSResource(payload: payload, offset: scanner.position, base: 0)!)
scanner.advance(by: answers.last!.bytesLength)
}
for _ in 0..<nameserverCount {
nameservers.append(DNSResource(payload: payload, offset: scanner.position, base: 0)!)
scanner.advance(by: nameservers.last!.bytesLength)
}
for _ in 0..<addtionalCount {
addtionals.append(DNSResource(payload: payload, offset: scanner.position, base: 0)!)
scanner.advance(by: addtionals.last!.bytesLength)
}
}
func buildMessage() -> Bool {
payload = Data(count: bytesLength)
if transactionID == 0 {
transactionID = UInt16(arc4random_uniform(UInt32(UInt16.max)))
}
setPayloadWithUInt16(transactionID, at: 0, swap: true)
var byte: UInt8 = 0
byte += messageType.rawValue << 7
if authoritative {
byte += 4
}
if truncation {
byte += 2
}
if recursionDesired {
byte += 1
}
setPayloadWithUInt8(byte, at: 2)
byte = 0
if recursionAvailable {
byte += 128
}
byte += status.rawValue
setPayloadWithUInt8(byte, at: 3)
setPayloadWithUInt16(UInt16(queries.count), at: 4, swap: true)
setPayloadWithUInt16(UInt16(answers.count), at: 6, swap: true)
setPayloadWithUInt16(UInt16(nameservers.count), at: 8, swap: true)
setPayloadWithUInt16(UInt16(addtionals.count), at: 10, swap: true)
return writeAllRecordAt(12)
}
// swiftlint:disable variable_name
func setPayloadWithUInt8(_ value: UInt8, at: Int) {
var v = value
withUnsafeBytes(of: &v) {
payload.replaceSubrange(at..<at+1, with: $0)
}
}
func setPayloadWithUInt16(_ value: UInt16, at: Int, swap: Bool = false) {
var v: UInt16
if swap {
v = NSSwapHostShortToBig(value)
} else {
v = value
}
withUnsafeBytes(of: &v) {
payload.replaceSubrange(at..<at+2, with: $0)
}
}
func setPayloadWithUInt32(_ value: UInt32, at: Int, swap: Bool = false) {
var v: UInt32
if swap {
v = NSSwapHostIntToBig(value)
} else {
v = value
}
withUnsafeBytes(of: &v) {
payload.replaceSubrange(at..<at+4, with: $0)
}
}
func setPayloadWithData(_ data: Data, at: Int, length: Int? = nil, from: Int = 0) {
let length = length ?? data.count - from
payload.withUnsafeMutableBytes { ptr in
data.copyBytes(to: ptr.baseAddress!.advanced(by: at).assumingMemoryBound(to: UInt8.self), from: from..<from+length)
}
}
func resetPayloadAt(_ at: Int, length: Int) {
payload.resetBytes(in: at..<at+length)
}
fileprivate func writeAllRecordAt(_ at: Int) -> Bool {
var position = at
for query in queries {
guard writeDNSQuery(query, at: position) else {
return false
}
position += query.bytesLength
}
for resources in [answers, nameservers, addtionals] {
for resource in resources {
guard writeDNSResource(resource, at: position) else {
return false
}
position += resource.bytesLength
}
}
return true
}
fileprivate func writeDNSQuery(_ query: DNSQuery, at: Int) -> Bool {
guard DNSNameConverter.setName(query.name, toData: &payload!, at: at) else {
return false
}
setPayloadWithUInt16(query.type.rawValue, at: at + query.nameBytesLength, swap: true)
setPayloadWithUInt16(query.klass.rawValue, at: at + query.nameBytesLength + 2, swap: true)
return true
}
fileprivate func writeDNSResource(_ resource: DNSResource, at: Int) -> Bool {
guard DNSNameConverter.setName(resource.name, toData: &payload!, at: at) else {
return false
}
setPayloadWithUInt16(resource.type.rawValue, at: at + resource.nameBytesLength, swap: true)
setPayloadWithUInt16(resource.klass.rawValue, at: at + resource.nameBytesLength + 2, swap: true)
setPayloadWithUInt32(resource.TTL, at: at + resource.nameBytesLength + 4, swap: true)
setPayloadWithUInt16(resource.dataLength, at: at + resource.nameBytesLength + 8, swap: true)
setPayloadWithData(resource.data, at: at + resource.nameBytesLength + 10)
return true
}
}
open class DNSQuery {
public let name: String
public let type: DNSType
public let klass: DNSClass
let nameBytesLength: Int
init(name: String, type: DNSType = .a, klass: DNSClass = .internet) {
self.name = name.trimmingCharacters(in: CharacterSet(charactersIn: "."))
self.type = type
self.klass = klass
self.nameBytesLength = name.utf8.count + 2
}
init?(payload: Data, offset: Int, base: Int = 0) {
(self.name, self.nameBytesLength) = DNSNameConverter.getNamefromData(payload, offset: offset, base: base)
let scanner = BinaryDataScanner(data: payload, littleEndian: false)
scanner.skip(to: offset + self.nameBytesLength)
guard let type = DNSType(rawValue: scanner.read16()!) else {
DDLogError("Received DNS packet with unknown type.")
return nil
}
self.type = type
guard let klass = DNSClass(rawValue: scanner.read16()!) else {
DDLogError("Received DNS packet with unknown class.")
return nil
}
self.klass = klass
}
var bytesLength: Int {
return nameBytesLength + 4
}
}
open class DNSResource {
public let name: String
public let type: DNSType
public let klass: DNSClass
public let TTL: UInt32
let dataLength: UInt16
public let data: Data
let nameBytesLength: Int
init(name: String, type: DNSType = .a, klass: DNSClass = .internet, TTL: UInt32 = 300, data: Data) {
self.name = name
self.type = type
self.klass = klass
self.TTL = TTL
dataLength = UInt16(data.count)
self.data = data
self.nameBytesLength = name.utf8.count + 2
}
static func ARecord(_ name: String, TTL: UInt32 = 300, address: IPAddress) -> DNSResource {
return DNSResource(name: name, type: .a, klass: .internet, TTL: TTL, data: address.dataInNetworkOrder)
}
init?(payload: Data, offset: Int, base: Int = 0) {
(self.name, self.nameBytesLength) = DNSNameConverter.getNamefromData(payload, offset: offset, base: base)
let scanner = BinaryDataScanner(data: payload, littleEndian: false)
scanner.skip(to: offset + self.nameBytesLength)
guard let type = DNSType(rawValue: scanner.read16()!) else {
DDLogError("Received DNS packet with unknown type.")
return nil
}
self.type = type
guard let klass = DNSClass(rawValue: scanner.read16()!) else {
DDLogError("Received DNS packet with unknown class.")
return nil
}
self.klass = klass
self.TTL = scanner.read32()!
dataLength = scanner.read16()!
self.data = payload.subdata(in: scanner.position..<scanner.position+Int(dataLength))
}
var bytesLength: Int {
return nameBytesLength + 10 + Int(dataLength)
}
var ipv4Address: IPAddress? {
guard type == .a else {
return nil
}
return IPAddress(fromBytesInNetworkOrder: (data as NSData).bytes)
}
}
class DNSNameConverter {
static func setName(_ name: String, toData data: inout Data, at: Int) -> Bool {
let labels = name.components(separatedBy: CharacterSet(charactersIn: "."))
var position = at
for label in labels {
let len = label.utf8.count
guard len != 0 else {
// invalid domain name
return false
}
data[position] = UInt8(len)
position += 1
data.replaceSubrange(position..<position+len, with: label.data(using: .utf8)!)
position += len
}
data[position] = 0
return true
}
static func getNamefromData(_ data: Data, offset: Int, base: Int = 0) -> (String, Int) {
let scanner = BinaryDataScanner(data: data, littleEndian: false)
scanner.skip(to: offset)
var len: UInt8 = 0
var name = ""
var currentReadBytes = 0
var jumped = false
var nameBytesLength = 0
repeat {
let length = scanner.read16()!
// is this a pointer?
if length & 0xC000 == 0xC000 {
if !jumped {
// save the length position
nameBytesLength = 2 + currentReadBytes
jumped = true
}
scanner.skip(to: Int(length & 0x3FFF) + base)
} else {
scanner.advance(by: -2)
}
len = scanner.readByte()!
currentReadBytes += 1
if len == 0 {
break
}
currentReadBytes += Int(len)
guard let label = String(bytes: scanner.data.subdata(in: scanner.position..<scanner.position+Int(len)), encoding: .utf8) else {
return ("", currentReadBytes)
}
// this is not efficient, but won't take much time, so maybe I'll optimize it later
name = name.appendingFormat(".%@", label)
scanner.advance(by: Int(len))
} while true
if !jumped {
nameBytesLength = currentReadBytes
}
return (name.trimmingCharacters(in: CharacterSet(charactersIn: ".")), nameBytesLength)
}
}

View File

@@ -0,0 +1,37 @@
import Foundation
public protocol DNSResolverProtocol: class {
var delegate: DNSResolverDelegate? { get set }
func resolve(session: DNSSession)
func stop()
}
public protocol DNSResolverDelegate: class {
func didReceive(rawResponse: Data)
}
open class UDPDNSResolver: DNSResolverProtocol, NWUDPSocketDelegate {
let socket: NWUDPSocket
public weak var delegate: DNSResolverDelegate?
public init(address: IPAddress, port: Port) {
socket = NWUDPSocket(host: address.presentation, port: Int(port.value))!
socket.delegate = self
}
public func resolve(session: DNSSession) {
socket.write(data: session.requestMessage.payload)
}
public func stop() {
socket.disconnect()
}
public func didReceive(data: Data, from: NWUDPSocket) {
delegate?.didReceive(rawResponse: data)
}
public func didCancel(socket: NWUDPSocket) {
}
}

View File

@@ -0,0 +1,269 @@
import Foundation
import NetworkExtension
/// A DNS server designed as an `IPStackProtocol` implementation which works with TUN interface.
///
/// This class is thread-safe.
open class DNSServer: DNSResolverDelegate, IPStackProtocol {
/// Current DNS server.
///
/// - warning: There is at most one DNS server running at the same time. If a DNS server is registered to `TUNInterface` then it must also be set here.
public static var currentServer: DNSServer?
/// The address of DNS server.
let serverAddress: IPAddress
/// The port of DNS server
let serverPort: Port
fileprivate let queue: DispatchQueue = QueueFactory.getQueue()
fileprivate var fakeSessions: [IPAddress: DNSSession] = [:]
fileprivate var pendingSessions: [UInt16: DNSSession] = [:]
fileprivate let pool: IPPool?
fileprivate var resolvers: [DNSResolverProtocol] = []
open var outputFunc: (([Data], [NSNumber]) -> Void)!
// Only match A record as of now, all other records should be passed directly.
fileprivate let matchedType = [DNSType.a]
/**
Initailize a DNS server.
- parameter address: The IP address of the server.
- parameter port: The listening port of the server.
- parameter fakeIPPool: The pool of fake IP addresses. Set to nil if no fake IP is needed.
*/
public init(address: IPAddress, port: Port, fakeIPPool: IPPool? = nil) {
serverAddress = address
serverPort = port
pool = fakeIPPool
}
/**
Clean up fake IP.
- parameter address: The fake IP address.
- parameter delay: How long should the fake IP be valid.
*/
fileprivate func cleanUpFakeIP(_ address: IPAddress, after delay: Int) {
queue.asyncAfter(deadline: DispatchTime.now() + Double(Int64(delay) * Int64(NSEC_PER_SEC)) / Double(NSEC_PER_SEC)) {
[weak self] in
_ = self?.fakeSessions.removeValue(forKey: address)
self?.pool?.release(ip: address)
}
}
/**
Clean up pending session.
- parameter session: The pending session.
- parameter delay: How long before the pending session be cleaned up.
*/
fileprivate func cleanUpPendingSession(_ session: DNSSession, after delay: Int) {
queue.asyncAfter(deadline: DispatchTime.now() + Double(Int64(delay) * Int64(NSEC_PER_SEC)) / Double(NSEC_PER_SEC)) {
[weak self] in
_ = self?.pendingSessions.removeValue(forKey: session.requestMessage.transactionID)
}
}
fileprivate func lookup(_ session: DNSSession) {
guard shouldMatch(session) else {
session.matchResult = .real
lookupRemotely(session)
return
}
RuleManager.currentManager.matchDNS(session, type: .domain)
switch session.matchResult! {
case .fake:
guard setUpFakeIP(session) else {
// failed to set up a fake IP, return the result directly
session.matchResult = .real
lookupRemotely(session)
return
}
outputSession(session)
case .real, .unknown:
lookupRemotely(session)
default:
DDLogError("The rule match result should never be .Pass.")
}
}
fileprivate func lookupRemotely(_ session: DNSSession) {
pendingSessions[session.requestMessage.transactionID] = session
cleanUpPendingSession(session, after: Opt.DNSPendingSessionLifeTime)
sendQueryToRemote(session)
}
fileprivate func sendQueryToRemote(_ session: DNSSession) {
for resolver in resolvers {
resolver.resolve(session: session)
}
}
/**
Input IP packet into the DNS server.
- parameter packet: The IP packet.
- parameter version: The version of the IP packet.
- returns: If the packet is taken in by this DNS server.
*/
open func input(packet: Data, version: NSNumber?) -> Bool {
guard IPPacket.peekProtocol(packet) == .udp else {
return false
}
guard IPPacket.peekDestinationAddress(packet) == serverAddress else {
return false
}
guard IPPacket.peekDestinationPort(packet) == serverPort else {
return false
}
guard let ipPacket = IPPacket(packetData: packet) else {
return false
}
guard let session = DNSSession(packet: ipPacket) else {
return false
}
queue.async {
self.lookup(session)
}
return true
}
public func start() {
}
open func stop() {
for resolver in resolvers {
resolver.stop()
}
resolvers = []
// The blocks scheduled with `dispatch_after` are ignored since they are hard to cancel. But there should be no consequence, everything will be released except for a few `IPAddress`es and the `queue` which will be released later.
}
fileprivate func outputSession(_ session: DNSSession) {
guard let result = session.matchResult else {
return
}
let udpParser = UDPProtocolParser()
udpParser.sourcePort = serverPort
// swiftlint:disable:next force_cast
udpParser.destinationPort = (session.requestIPPacket!.protocolParser as! UDPProtocolParser).sourcePort
switch result {
case .real:
udpParser.payload = session.realResponseMessage!.payload
case .fake:
let response = DNSMessage()
response.transactionID = session.requestMessage.transactionID
response.messageType = .response
response.recursionAvailable = true
// since we only support ipv4 as of now, it must be an answer of type A
response.answers.append(DNSResource.ARecord(session.requestMessage.queries[0].name, TTL: UInt32(Opt.DNSFakeIPTTL), address: session.fakeIP!))
session.expireAt = Date().addingTimeInterval(Double(Opt.DNSFakeIPTTL))
guard response.buildMessage() else {
DDLogError("Failed to build DNS response.")
return
}
udpParser.payload = response.payload
default:
return
}
let ipPacket = IPPacket()
ipPacket.sourceAddress = serverAddress
ipPacket.destinationAddress = session.requestIPPacket!.sourceAddress
ipPacket.protocolParser = udpParser
ipPacket.transportProtocol = .udp
ipPacket.buildPacket()
outputFunc([ipPacket.packetData], [NSNumber(value: AF_INET as Int32)])
}
fileprivate func shouldMatch(_ session: DNSSession) -> Bool {
return matchedType.contains(session.requestMessage.type!)
}
func isFakeIP(_ ipAddress: IPAddress) -> Bool {
return pool?.contains(ip: ipAddress) ?? false
}
func lookupFakeIP(_ address: IPAddress) -> DNSSession? {
var session: DNSSession?
QueueFactory.executeOnQueueSynchronizedly {
session = self.fakeSessions[address]
}
return session
}
/**
Add new DNS resolver to DNS server.
- parameter resolver: The resolver to add.
*/
open func registerResolver(_ resolver: DNSResolverProtocol) {
resolver.delegate = self
resolvers.append(resolver)
}
fileprivate func setUpFakeIP(_ session: DNSSession) -> Bool {
guard let fakeIP = pool?.fetchIP() else {
DDLogVerbose("Failed to get a fake IP.")
return false
}
session.fakeIP = fakeIP
fakeSessions[fakeIP] = session
session.expireAt = Date().addingTimeInterval(TimeInterval(Opt.DNSFakeIPTTL))
// keep the fake session for 2 TTL
cleanUpFakeIP(fakeIP, after: Opt.DNSFakeIPTTL * 2)
return true
}
open func didReceive(rawResponse: Data) {
guard let message = DNSMessage(payload: rawResponse) else {
DDLogError("Failed to parse response from remote DNS server.")
return
}
queue.async {
guard let session = self.pendingSessions.removeValue(forKey: message.transactionID) else {
// this should not be a problem if there are multiple DNS servers or the DNS server is hijacked.
DDLogVerbose("Do not find the corresponding DNS session for the response.")
return
}
session.realResponseMessage = message
session.realIP = message.resolvedIPv4Address
if session.matchResult != .fake && session.matchResult != .real {
RuleManager.currentManager.matchDNS(session, type: .ip)
}
switch session.matchResult! {
case .fake:
if !self.setUpFakeIP(session) {
// return real response
session.matchResult = .real
}
self.outputSession(session)
case .real:
self.outputSession(session)
default:
DDLogError("The rule match result should never be .Pass or .Unknown in IP mode.")
}
}
}
}

View File

@@ -0,0 +1,49 @@
import Foundation
open class DNSSession {
public let requestMessage: DNSMessage
var requestIPPacket: IPPacket?
open var realIP: IPAddress?
open var fakeIP: IPAddress?
open var realResponseMessage: DNSMessage?
var realResponseIPPacket: IPPacket?
open var matchedRule: Rule?
open var matchResult: DNSSessionMatchResult?
var indexToMatch = 0
var expireAt: Date?
// lazy var countryCode: String? = {
// [unowned self] in
// guard self.realIP != nil else {
// return nil
// }
// return Utils.GeoIPLookup.Lookup(self.realIP!.presentation)
// }()
init?(message: DNSMessage) {
guard message.messageType == .query else {
DDLogError("DNSSession can only be initailized by a DNS query.")
return nil
}
guard message.queries.count == 1 else {
DDLogError("Expecting the DNS query has exact one query entry.")
return nil
}
requestMessage = message
}
convenience init?(packet: IPPacket) {
guard let message = DNSMessage(payload: packet.protocolParser.payload) else {
return nil
}
self.init(message: message)
requestIPPacket = packet
}
}
extension DNSSession: CustomStringConvertible {
public var description: String {
return "<\(type(of: self)) domain: \(self.requestMessage.queries.first!.name) realIP: \(String(describing: realIP)) fakeIP: \(String(describing: fakeIP))>"
}
}

View File

@@ -0,0 +1,34 @@
import Foundation
/// The protocol defines an IP stack.
public protocol IPStackProtocol: class {
/**
Input a packet into the stack.
- parameter packet: The IP packet.
- parameter version: The version of the IP packet, i.e., AF_INET, AF_INET6.
- returns: If the stack takes in this packet. If the packet is taken in, then it won't be processed by other IP stacks.
*/
func input(packet: Data, version: NSNumber?) -> Bool
/// This is called when this stack decided to output some IP packet. This is set automatically when the stack is registered to some interface.
///
/// The parameter is the safe as the `inputPacket`.
///
/// - note: This block is thread-safe.
var outputFunc: (([Data], [NSNumber]) -> Void)! { get set }
func start()
/**
Stop the stack from running.
This is called when the interface this stack is registered to stop to processing packets and will be released soon.
*/
func stop()
}
extension IPStackProtocol {
public func stop() {}
}

View File

@@ -0,0 +1,76 @@
//import Foundation
//
//enum ChangeType {
// case Address, Port
//}
//
//public class IPMutablePacket {
// // Support only IPv4 for now
//
// let version: IPVersion
// let proto: TransportType
// let IPHeaderLength: Int
// var sourceAddress: IPv4Address {
// get {
// return IPv4Address(fromBytesInNetworkOrder: payload.bytes.advancedBy(12))
// }
// set {
// setIPv4Address(sourceAddress, newAddress: newValue, at: 12)
// }
// }
// var destinationAddress: IPv4Address {
// get {
// return IPv4Address(fromBytesInNetworkOrder: payload.bytes.advancedBy(16))
// }
// set {
// setIPv4Address(destinationAddress, newAddress: newValue, at: 16)
// }
// }
//
// let payload: NSMutableData
//
// public init(payload: NSData) {
// let vl = UnsafePointer<UInt8>(payload.bytes).memory
// version = IPVersion(rawValue: vl >> 4)!
// IPHeaderLength = Int(vl & 0x0F) * 4
// let p = UnsafePointer<UInt8>(payload.bytes.advancedBy(9)).memory
// proto = TransportType(rawValue: p)!
// self.payload = NSMutableData(data: payload)
// }
//
// func updateChecksum(oldValue: UInt16, newValue: UInt16, type: ChangeType) {
// if type == .Address {
// updateChecksum(oldValue, newValue: newValue, at: 10)
// }
// }
//
// // swiftlint:disable:next variable_name
// internal func updateChecksum(oldValue: UInt16, newValue: UInt16, at: Int) {
// let oldChecksum = UnsafePointer<UInt16>(payload.bytes.advancedBy(at)).memory
// let oc32 = UInt32(~oldChecksum)
// let ov32 = UInt32(~oldValue)
// let nv32 = UInt32(newValue)
// var newChecksum32 = oc32 &+ ov32 &+ nv32
// newChecksum32 = (newChecksum32 & 0xFFFF) + (newChecksum32 >> 16)
// newChecksum32 = (newChecksum32 & 0xFFFF) &+ (newChecksum32 >> 16)
// var newChecksum = ~UInt16(newChecksum32)
// payload.replaceBytesInRange(NSRange(location: at, length: 2), withBytes: &newChecksum, length: 2)
// }
//
// // swiftlint:disable:next variable_name
// private func foldChecksum(checksum: UInt32) -> UInt32 {
// var checksum = checksum
// while checksum > 0xFFFF {
// checksum = (checksum & 0xFFFF) + (checksum >> 16)
// }
// return checksum
// }
//
// // swiftlint:disable:next variable_name
// private func setIPv4Address(oldAddress: IPv4Address, newAddress: IPv4Address, at: Int) {
// payload.replaceBytesInRange(NSRange(location: at, length: 4), withBytes: newAddress.bytesInNetworkOrder, length: 4)
// updateChecksum(UnsafePointer<UInt16>(oldAddress.bytesInNetworkOrder).memory, newValue: UnsafePointer<UInt16>(newAddress.bytesInNetworkOrder).memory, type: .Address)
// updateChecksum(UnsafePointer<UInt16>(oldAddress.bytesInNetworkOrder).advancedBy(1).memory, newValue: UnsafePointer<UInt16>(newAddress.bytesInNetworkOrder).advancedBy(1).memory, type: .Address)
// }
//
//}

View File

@@ -0,0 +1,330 @@
import Foundation
public enum IPVersion: UInt8 {
case iPv4 = 4, iPv6 = 6
}
public enum TransportProtocol: UInt8 {
case icmp = 1, tcp = 6, udp = 17
}
/// The class to process and build IP packet.
///
/// - note: Only IPv4 is supported as of now.
open class IPPacket {
/**
Get the version of the IP Packet without parsing the whole packet.
- parameter data: The data containing the whole IP packet.
- returns: The version of the packet. Returns `nil` if failed to parse the packet.
*/
public static func peekIPVersion(_ data: Data) -> IPVersion? {
guard data.count >= 20 else {
return nil
}
let version = (data as NSData).bytes.bindMemory(to: UInt8.self, capacity: data.count).pointee >> 4
return IPVersion(rawValue: version)
}
/**
Get the protocol of the IP Packet without parsing the whole packet.
- parameter data: The data containing the whole IP packet.
- returns: The protocol of the packet. Returns `nil` if failed to parse the packet.
*/
public static func peekProtocol(_ data: Data) -> TransportProtocol? {
guard data.count >= 20 else {
return nil
}
return TransportProtocol(rawValue: (data as NSData).bytes.bindMemory(to: UInt8.self, capacity: data.count).advanced(by: 9).pointee)
}
/**
Get the source IP address of the IP packet without parsing the whole packet.
- parameter data: The data containing the whole IP packet.
- returns: The source IP address of the packet. Returns `nil` if failed to parse the packet.
*/
public static func peekSourceAddress(_ data: Data) -> IPAddress? {
guard data.count >= 20 else {
return nil
}
return IPAddress(fromBytesInNetworkOrder: (data as NSData).bytes.advanced(by: 12))
}
/**
Get the destination IP address of the IP packet without parsing the whole packet.
- parameter data: The data containing the whole IP packet.
- returns: The destination IP address of the packet. Returns `nil` if failed to parse the packet.
*/
public static func peekDestinationAddress(_ data: Data) -> IPAddress? {
guard data.count >= 20 else {
return nil
}
return IPAddress(fromBytesInNetworkOrder: (data as NSData).bytes.advanced(by: 16))
}
/**
Get the source port of the IP packet without parsing the whole packet.
- parameter data: The data containing the whole IP packet.
- returns: The source IP address of the packet. Returns `nil` if failed to parse the packet.
- note: Only TCP and UDP packet has port field.
*/
public static func peekSourcePort(_ data: Data) -> Port? {
guard let proto = peekProtocol(data) else {
return nil
}
guard proto == .tcp || proto == .udp else {
return nil
}
let headerLength = Int((data as NSData).bytes.bindMemory(to: UInt8.self, capacity: data.count).pointee & 0x0F * 4)
// Make sure there are bytes for source and destination bytes.
guard data.count > headerLength + 4 else {
return nil
}
return Port(bytesInNetworkOrder: (data as NSData).bytes.advanced(by: headerLength))
}
/**
Get the destination port of the IP packet without parsing the whole packet.
- parameter data: The data containing the whole IP packet.
- returns: The destination IP address of the packet. Returns `nil` if failed to parse the packet.
- note: Only TCP and UDP packet has port field.
*/
public static func peekDestinationPort(_ data: Data) -> Port? {
guard let proto = peekProtocol(data) else {
return nil
}
guard proto == .tcp || proto == .udp else {
return nil
}
let headerLength = Int((data as NSData).bytes.bindMemory(to: UInt8.self, capacity: data.count).pointee & 0x0F * 4)
// Make sure there are bytes for source and destination bytes.
guard data.count > headerLength + 4 else {
return nil
}
return Port(bytesInNetworkOrder: (data as NSData).bytes.advanced(by: headerLength + 2))
}
/// The version of the current IP packet.
open var version: IPVersion = .iPv4
/// The length of the IP packet header.
open var headerLength: UInt8 = 20
/// This contains the DSCP and ECN of the IP packet.
///
/// - note: Since we can not send custom IP packet out with NetworkExtension, this is useless and simply ignored.
open var tos: UInt8 = 0
/// This should be the length of the datagram.
/// This value is not read from header since NEPacketTunnelFlow has already taken care of it for us.
open var totalLength: UInt16 {
return UInt16(packetData.count)
}
/// Identification of the current packet.
///
/// - note: Since we do not support fragment, this is ignored and always will be zero.
/// - note: Theoratically, this should be a sequentially increasing number. It probably will be implemented.
var identification: UInt16 = 0
/// Offset of the current packet.
///
/// - note: Since we do not support fragment, this is ignored and always will be zero.
var offset: UInt16 = 0
/// TTL of the packet.
var TTL: UInt8 = 64
/// Source IP address.
var sourceAddress: IPAddress!
/// Destination IP address.
var destinationAddress: IPAddress!
/// Transport protocol of the packet.
var transportProtocol: TransportProtocol!
/// Parser to parse the payload in IP packet.
var protocolParser: TransportProtocolParserProtocol!
/// The data representing the packet.
var packetData: Data!
/**
Initailize a new instance to build IP packet.
*/
init() {}
/**
Initailize an `IPPacket` with data.
- parameter packetData: The data containing a whole packet.
*/
init?(packetData: Data) {
// no need to validate the packet.
self.packetData = packetData
let scanner = BinaryDataScanner(data: packetData, littleEndian: false)
let vhl = scanner.readByte()!
guard let v = IPVersion(rawValue: vhl >> 4) else {
DDLogError("Got unknown ip packet version \(vhl >> 4)")
return nil
}
version = v
headerLength = vhl & 0x0F * 4
guard packetData.count >= Int(headerLength) else {
return nil
}
tos = scanner.readByte()!
guard totalLength == scanner.read16()! else {
DDLogError("Packet length mismatches from header.")
return nil
}
identification = scanner.read16()!
offset = scanner.read16()!
TTL = scanner.readByte()!
guard let proto = TransportProtocol(rawValue: scanner.readByte()!) else {
DDLogWarn("Get unsupported packet protocol.")
return nil
}
transportProtocol = proto
// ignore checksum
_ = scanner.read16()!
switch version {
case .iPv4:
sourceAddress = IPAddress(ipv4InNetworkOrder: CFSwapInt32(scanner.read32()!))
destinationAddress = IPAddress(ipv4InNetworkOrder: CFSwapInt32(scanner.read32()!))
default:
// IPv6 is not supported yet.
DDLogWarn("IPv6 is not supported yet.")
return nil
}
switch transportProtocol! {
case .udp:
guard let parser = UDPProtocolParser(packetData: packetData, offset: Int(headerLength)) else {
return nil
}
self.protocolParser = parser
default:
DDLogError("Can not parse packet header of type \(String(describing: transportProtocol)) yet")
return nil
}
}
func computePseudoHeaderChecksum() -> UInt32 {
var result: UInt32 = 0
if let address = sourceAddress {
result += address.UInt32InNetworkOrder! >> 16 + address.UInt32InNetworkOrder! & 0xFFFF
}
if let address = destinationAddress {
result += address.UInt32InNetworkOrder! >> 16 + address.UInt32InNetworkOrder! & 0xFFFF
}
result += UInt32(transportProtocol.rawValue) << 8
result += CFSwapInt32(UInt32(protocolParser.bytesLength))
return result
}
func buildPacket() {
packetData = NSMutableData(length: Int(headerLength) + protocolParser.bytesLength) as Data?
// set header
setPayloadWithUInt8(headerLength / 4 + version.rawValue << 4, at: 0)
setPayloadWithUInt8(tos, at: 1)
setPayloadWithUInt16(totalLength, at: 2)
setPayloadWithUInt16(identification, at: 4)
setPayloadWithUInt16(offset, at: 6)
setPayloadWithUInt8(TTL, at: 8)
setPayloadWithUInt8(transportProtocol.rawValue, at: 9)
// clear checksum bytes
resetPayloadAt(10, length: 2)
setPayloadWithUInt32(sourceAddress.UInt32InNetworkOrder!, at: 12, swap: false)
setPayloadWithUInt32(destinationAddress.UInt32InNetworkOrder!, at: 16, swap: false)
// let TCP or UDP packet build
protocolParser.packetData = packetData
protocolParser.offset = Int(headerLength)
protocolParser.buildSegment(computePseudoHeaderChecksum())
packetData = protocolParser.packetData
setPayloadWithUInt16(Checksum.computeChecksum(packetData, from: 0, to: Int(headerLength)), at: 10, swap: false)
}
func setPayloadWithUInt8(_ value: UInt8, at: Int) {
var v = value
withUnsafeBytes(of: &v) {
packetData.replaceSubrange(at..<at+1, with: $0)
}
}
func setPayloadWithUInt16(_ value: UInt16, at: Int, swap: Bool = true) {
var v: UInt16
if swap {
v = CFSwapInt16HostToBig(value)
} else {
v = value
}
withUnsafeBytes(of: &v) {
packetData.replaceSubrange(at..<at+2, with: $0)
}
}
func setPayloadWithUInt32(_ value: UInt32, at: Int, swap: Bool = true) {
var v: UInt32
if swap {
v = CFSwapInt32HostToBig(value)
} else {
v = value
}
withUnsafeBytes(of: &v) {
packetData.replaceSubrange(at..<at+4, with: $0)
}
}
func setPayloadWithData(_ data: Data, at: Int, length: Int? = nil, from: Int = 0) {
var length = length
if length == nil {
length = data.count - from
}
packetData.replaceSubrange(at..<at+length!, with: data)
}
func resetPayloadAt(_ at: Int, length: Int) {
packetData.resetBytes(in: at..<at+length)
}
}

View File

@@ -0,0 +1,72 @@
import Foundation
protocol TransportProtocolParserProtocol {
var packetData: Data! { get set }
var offset: Int { get set }
var bytesLength: Int { get }
var payload: Data! { get set }
func buildSegment(_ pseudoHeaderChecksum: UInt32)
}
/// Parser to process UDP packet and build packet.
class UDPProtocolParser: TransportProtocolParserProtocol {
/// The source port.
var sourcePort: Port!
/// The destination port.
var destinationPort: Port!
/// The data containing the UDP segment.
var packetData: Data!
/// The offset of the UDP segment in the `packetData`.
var offset: Int = 0
/// The payload to be encapsulated.
var payload: Data!
/// The length of the UDP segment.
var bytesLength: Int {
return payload.count + 8
}
init() {}
init?(packetData: Data, offset: Int) {
guard packetData.count >= offset + 8 else {
return nil
}
self.packetData = packetData
self.offset = offset
sourcePort = Port(bytesInNetworkOrder: (packetData as NSData).bytes.advanced(by: offset))
destinationPort = Port(bytesInNetworkOrder: (packetData as NSData).bytes.advanced(by: offset + 2))
payload = packetData.subdata(in: offset+8..<packetData.count)
}
func buildSegment(_ pseudoHeaderChecksum: UInt32) {
sourcePort.withUnsafeBufferPointer {
self.packetData.replaceSubrange(offset..<offset+2, with: $0)
}
destinationPort.withUnsafeBufferPointer {
self.packetData.replaceSubrange(offset+2..<offset+4, with: $0)
}
var length = NSSwapHostShortToBig(UInt16(bytesLength))
withUnsafeBytes(of: &length) {
packetData.replaceSubrange(offset+4..<offset+6, with: $0)
}
packetData.replaceSubrange(offset+8..<offset+8+payload.count, with: payload)
packetData.resetBytes(in: offset+6..<offset+8)
var checksum = Checksum.computeChecksum(packetData, from: offset, to: nil, withPseudoHeaderChecksum: pseudoHeaderChecksum)
withUnsafeBytes(of: &checksum) {
packetData.replaceSubrange(offset+6..<offset+8, with: $0)
}
}
}

View File

@@ -0,0 +1,32 @@
//import Foundation
//
//class TCPMutablePacket: IPMutablePacket {
// var sourcePort: Port {
// get {
// return Port(bytesInNetworkOrder: payload.bytes.advancedBy(IPHeaderLength))
// }
// set {
// setPort(sourcePort, newPort: newValue, at: 0)
// }
// }
//
// var destinationPort: Port {
// get {
// return Port(bytesInNetworkOrder: payload.bytes.advancedBy(IPHeaderLength + 2))
// }
// set {
// setPort(destinationPort, newPort: newValue, at: 2)
// }
// }
//
// override func updateChecksum(oldValue: UInt16, newValue: UInt16, type: ChangeType) {
// super.updateChecksum(oldValue, newValue: newValue, type: type)
// updateChecksum(oldValue, newValue: newValue, at: IPHeaderLength + 16)
// }
//
// // swiftlint:disable:next variable_name
// private func setPort(oldPort: Port, newPort: Port, at: Int) {
// payload.replaceBytesInRange(NSRange(location: at + IPHeaderLength, length: 2), withBytes: newPort.bytesInNetworkOrder, length: 2)
// updateChecksum(oldPort.valueInNetworkOrder, newValue: newPort.valueInNetworkOrder, type: .Port)
// }
//}

View File

@@ -0,0 +1,158 @@
import Foundation
/// Representing all the information in one connect session.
public final class ConnectSession {
public enum EventSourceEnum {
case proxy, adapter, tunnel
}
/// The requested host.
///
/// This is the host received in the request. May be a domain, a real IP or a fake IP.
public let requestedHost: String
/// The real host for this session.
///
/// If the session is initailized with a host domain, then `host == requestedHost`.
/// Otherwise, the requested IP address is looked up in the DNS server to see if it corresponds to a domain if `fakeIPEnabled` is `true`.
/// Unless there is a good reason not to, any socket shoule connect based on this directly.
public var host: String
/// The requested port.
public let port: Int
/// The rule to use to connect to remote.
public var matchedRule: Rule?
/// Whether If the `requestedHost` is an IP address.
public let fakeIPEnabled: Bool
public var error: Error?
public var errorSource: EventSourceEnum?
public var disconnectedBy: EventSourceEnum?
/// The resolved IP address.
///
/// - note: This will always be real IP address.
public lazy var ipAddress: String = {
[unowned self] in
if self.isIP() {
return self.host
} else {
let ip = Utils.DNS.resolve(self.host)
guard self.fakeIPEnabled else {
return ip
}
guard let dnsServer = DNSServer.currentServer else {
return ip
}
guard let address = IPAddress(fromString: ip) else {
return ip
}
guard dnsServer.isFakeIP(address) else {
return ip
}
guard let session = dnsServer.lookupFakeIP(address) else {
return ip
}
return session.realIP?.presentation ?? ""
}
}()
/// The location of the host.
// public lazy var country: String = {
// [unowned self] in
// guard let c = Utils.GeoIPLookup.Lookup(self.ipAddress) else {
// return ""
// }
// return c
// }()
public init?(host: String, port: Int, fakeIPEnabled: Bool = true) {
self.requestedHost = host
self.port = port
self.fakeIPEnabled = fakeIPEnabled
self.host = host
if fakeIPEnabled {
guard lookupRealIP() else {
return nil
}
}
}
public convenience init?(ipAddress: IPAddress, port: Port, fakeIPEnabled: Bool = true) {
self.init(host: ipAddress.presentation, port: Int(port.value), fakeIPEnabled: fakeIPEnabled)
}
func disconnected(becauseOf error: Error? = nil, by source: EventSourceEnum) {
if disconnectedBy == nil {
self.error = error
if error != nil {
errorSource = source
}
disconnectedBy = source
}
}
fileprivate func lookupRealIP() -> Bool {
/// If custom DNS server is set up.
guard let dnsServer = DNSServer.currentServer else {
return true
}
// Only IPv4 is supported as of now.
guard isIPv4() else {
return true
}
let address = IPAddress(fromString: requestedHost)!
guard dnsServer.isFakeIP(address) else {
return true
}
// Look up fake IP reversely should never fail.
guard let session = dnsServer.lookupFakeIP(address) else {
return false
}
host = session.requestMessage.queries[0].name
ipAddress = session.realIP?.presentation ?? ""
matchedRule = session.matchedRule
// if session.countryCode != nil {
// country = session.countryCode!
// }
return true
}
public func isIPv4() -> Bool {
return Utils.IP.isIPv4(host)
}
public func isIPv6() -> Bool {
return Utils.IP.isIPv6(host)
}
public func isIP() -> Bool {
return isIPv4() || isIPv6()
}
}
extension ConnectSession: CustomStringConvertible {
public var description: String {
if requestedHost != host {
return "<\(type(of: self)) host:\(host) port:\(port) requestedHost:\(requestedHost)>"
} else {
return "<\(type(of: self)) host:\(host) port:\(port)>"
}
}
}

View File

@@ -0,0 +1,200 @@
import Foundation
open class HTTPHeader {
public enum HTTPHeaderError: Error {
case malformedHeader, invalidRequestLine, invalidHeaderField, invalidConnectURL, invalidConnectPort, invalidURL, missingHostField, invalidHostField, invalidHostPort, invalidContentLength, illegalEncoding
}
open var HTTPVersion: String
open var method: String
open var isConnect: Bool = false
open var path: String
open var foundationURL: Foundation.URL?
open var homemadeURL: HTTPURL?
open var host: String
open var port: Int
// just assume that `Content-Length` is given as of now.
// Chunk is not supported yet.
open var contentLength: Int = 0
open var headers: [(String, String)] = []
open var rawHeader: Data?
public init(headerString: String) throws {
let lines = headerString.components(separatedBy: "\r\n")
guard lines.count >= 3 else {
throw HTTPHeaderError.malformedHeader
}
let request = lines[0].components(separatedBy: " ")
guard request.count == 3 else {
throw HTTPHeaderError.invalidRequestLine
}
method = request[0]
path = request[1]
HTTPVersion = request[2]
for line in lines[1..<lines.count-2] {
let header = line.split(separator: ":", maxSplits: 1, omittingEmptySubsequences: false)
guard header.count == 2 else {
throw HTTPHeaderError.invalidHeaderField
}
headers.append((String(header[0]).trimmingCharacters(in: CharacterSet.whitespaces), String(header[1]).trimmingCharacters(in: CharacterSet.whitespaces)))
}
if method.uppercased() == "CONNECT" {
isConnect = true
let urlInfo = path.components(separatedBy: ":")
guard urlInfo.count == 2 else {
throw HTTPHeaderError.invalidConnectURL
}
host = urlInfo[0]
guard let port = Int(urlInfo[1]) else {
throw HTTPHeaderError.invalidConnectPort
}
self.port = port
self.contentLength = 0
} else {
var resolved = false
host = ""
port = 80
if let _url = Foundation.URL(string: path) {
foundationURL = _url
if foundationURL!.host != nil {
host = foundationURL!.host!
port = foundationURL!.port ?? 80
resolved = true
}
} else {
guard let _url = HTTPURL(string: path) else {
throw HTTPHeaderError.invalidURL
}
homemadeURL = _url
if homemadeURL!.host != nil {
host = homemadeURL!.host!
port = homemadeURL!.port ?? 80
resolved = true
}
}
if !resolved {
var url: String = ""
for (key, value) in headers {
if "Host".caseInsensitiveCompare(key) == .orderedSame {
url = value
break
}
}
guard url != "" else {
throw HTTPHeaderError.missingHostField
}
let urlInfo = url.components(separatedBy: ":")
guard urlInfo.count <= 2 else {
throw HTTPHeaderError.invalidHostField
}
if urlInfo.count == 2 {
host = urlInfo[0]
guard let port = Int(urlInfo[1]) else {
throw HTTPHeaderError.invalidHostPort
}
self.port = port
} else {
host = urlInfo[0]
port = 80
}
}
for (key, value) in headers {
if "Content-Length".caseInsensitiveCompare(key) == .orderedSame {
guard let contentLength = Int(value) else {
throw HTTPHeaderError.invalidContentLength
}
self.contentLength = contentLength
break
}
}
}
}
public convenience init(headerData: Data) throws {
guard let headerString = String(data: headerData, encoding: .utf8) else {
throw HTTPHeaderError.illegalEncoding
}
try self.init(headerString: headerString)
rawHeader = headerData
}
open subscript(index: String) -> String? {
get {
for (key, value) in headers {
if index.caseInsensitiveCompare(key) == .orderedSame {
return value
}
}
return nil
}
}
open func toData() -> Data {
return toString().data(using: String.Encoding.utf8)!
}
open func toString() -> String {
var strRep = "\(method) \(path) \(HTTPVersion)\r\n"
for (key, value) in headers {
strRep += "\(key): \(value)\r\n"
}
strRep += "\r\n"
return strRep
}
open func addHeader(_ key: String, value: String) {
headers.append((key, value))
}
open func rewriteToRelativePath() {
if path[path.startIndex] != "/" {
guard let rewrotePath = URL.matchRelativePath(path) else {
return
}
path = rewrotePath
}
}
open func removeHeader(_ key: String) -> String? {
for i in 0..<headers.count {
if headers[i].0.caseInsensitiveCompare(key) == .orderedSame {
let (_, value) = headers.remove(at: i)
return value
}
}
return nil
}
open func removeProxyHeader() {
let ProxyHeader = ["Proxy-Authenticate", "Proxy-Authorization", "Proxy-Connection"]
for header in ProxyHeader {
_ = removeHeader(header)
}
}
struct URL {
// swiftlint:disable:next force_try
static let relativePathRegex = try! NSRegularExpression(pattern: "http.?:\\/\\/.*?(\\/.*)", options: NSRegularExpression.Options.caseInsensitive)
static func matchRelativePath(_ url: String) -> String? {
if let result = relativePathRegex.firstMatch(in: url, options: NSRegularExpression.MatchingOptions(), range: NSRange(location: 0, length: url.count)) {
return (url as NSString).substring(with: result.range(at: 1))
} else {
return nil
}
}
}
}

View File

@@ -0,0 +1,24 @@
import Foundation
public struct Opt {
public static var MAXNWTCPSocketReadDataSize = 128 * 1024
// This is only used in finding the end of HTTP header (as of now). There is no limit on the length of http header, but Apache set it to 8KB
public static var MAXNWTCPScanLength = 8912
public static var DNSFakeIPTTL = 300
public static var DNSPendingSessionLifeTime = 10
public static var UDPSocketActiveTimeout = 300
public static var UDPSocketActiveCheckInterval = 60
public static var MAXHTTPContentBlockLength = 10240
public static var RejectAdapterDefaultDelay = 300
public static var DNSTimeout = 1
public static var forwardReadInterval = 50
}

View File

@@ -0,0 +1,24 @@
import Foundation
/// The HTTP proxy server.
public final class GCDHTTPProxyServer: GCDProxyServer {
/**
Create an instance of HTTP proxy server.
- parameter address: The address of proxy server.
- parameter port: The port of proxy server.
*/
override public init(address: IPAddress?, port: Port) {
super.init(address: address, port: port)
}
/**
Handle the new accepted socket as a HTTP proxy connection.
- parameter socket: The accepted socket.
*/
override public func handleNewGCDSocket(_ socket: GCDTCPSocket) {
let proxySocket = HTTPProxySocket(socket: socket)
didAcceptNewSocket(proxySocket)
}
}

View File

@@ -0,0 +1,61 @@
import Foundation
/// Proxy server which listens on some port by GCDAsyncSocket.
///
/// This shoule be the base class for any concrete implementation of proxy server (e.g., HTTP or SOCKS5) which needs to listen on some port.
open class GCDProxyServer: ProxyServer, GCDAsyncSocketDelegate {
fileprivate var listenSocket: GCDAsyncSocket!
/**
Start the proxy server which creates a GCDAsyncSocket listening on specific port.
- throws: The error occured when starting the proxy server.
*/
override open func start() throws {
try QueueFactory.executeOnQueueSynchronizedly {
listenSocket = GCDAsyncSocket(delegate: self, delegateQueue: QueueFactory.getQueue(), socketQueue: QueueFactory.getQueue())
try listenSocket.accept(onInterface: address?.presentation, port: port.value)
try super.start()
}
}
/**
Stop the proxy server.
*/
override open func stop() {
QueueFactory.executeOnQueueSynchronizedly {
listenSocket?.setDelegate(nil, delegateQueue: nil)
listenSocket?.disconnect()
listenSocket = nil
super.stop()
}
}
/**
Delegate method to handle the newly accepted GCDTCPSocket.
Only this method should be overrided in any concrete implementation of proxy server which listens on some port with GCDAsyncSocket.
- parameter socket: The accepted socket.
*/
open func handleNewGCDSocket(_ socket: GCDTCPSocket) {
}
/**
GCDAsyncSocket delegate callback.
- parameter sock: The listening GCDAsyncSocket.
- parameter newSocket: The accepted new GCDAsyncSocket.
- warning: Do not call this method. This should be marked private but have to be marked public since the `GCDAsyncSocketDelegate` is public.
*/
open func socket(_ sock: GCDAsyncSocket, didAcceptNewSocket newSocket: GCDAsyncSocket) {
let gcdTCPSocket = GCDTCPSocket(socket: newSocket)
handleNewGCDSocket(gcdTCPSocket)
}
public func newSocketQueueForConnection(fromAddress address: Data, on sock: GCDAsyncSocket) -> DispatchQueue? {
return QueueFactory.getQueue()
}
}

View File

@@ -0,0 +1,24 @@
import Foundation
/// The SOCKS5 proxy server.
public final class GCDSOCKS5ProxyServer: GCDProxyServer {
/**
Create an instance of SOCKS5 proxy server.
- parameter address: The address of proxy server.
- parameter port: The port of proxy server.
*/
override public init(address: IPAddress?, port: Port) {
super.init(address: address, port: port)
}
/**
Handle the new accepted socket as a SOCKS5 proxy connection.
- parameter socket: The accepted socket.
*/
override public func handleNewGCDSocket(_ socket: GCDTCPSocket) {
let proxySocket = SOCKS5ProxySocket(socket: socket)
didAcceptNewSocket(proxySocket)
}
}

View File

@@ -0,0 +1,105 @@
import Foundation
/**
The base proxy server class.
This proxy does not listen on any port.
*/
open class ProxyServer: NSObject, TunnelDelegate {
typealias TunnelArray = [Tunnel]
/// The port of proxy server.
public let port: Port
/// The address of proxy server.
public let address: IPAddress?
/// The type of the proxy server.
///
/// This can be set to anything describing the proxy server.
public let type: String
/// The description of proxy server.
open override var description: String {
return "<\(type) address:\(String(describing: address)) port:\(port)>"
}
open var observer: Observer<ProxyServerEvent>?
var tunnels: TunnelArray = []
/**
Create an instance of proxy server.
- parameter address: The address of proxy server.
- parameter port: The port of proxy server.
- warning: If you are using Network Extension, you have to set address or you may not able to connect to the proxy server.
*/
public init(address: IPAddress?, port: Port) {
self.address = address
self.port = port
type = "\(Swift.type(of: self))"
super.init()
self.observer = ObserverFactory.currentFactory?.getObserverForProxyServer(self)
}
/**
Start the proxy server.
- throws: The error occured when starting the proxy server.
*/
open func start() throws {
QueueFactory.executeOnQueueSynchronizedly {
GlobalIntializer.initalize()
self.observer?.signal(.started(self))
}
}
/**
Stop the proxy server.
*/
open func stop() {
QueueFactory.executeOnQueueSynchronizedly {
for tunnel in tunnels {
tunnel.forceClose()
}
observer?.signal(.stopped(self))
}
}
/**
Delegate method when the proxy server accepts a new ProxySocket from local.
When implementing a concrete proxy server, e.g., HTTP proxy server, the server should listen on some port and then wrap the raw socket in a corresponding ProxySocket subclass, then call this method.
- parameter socket: The accepted proxy socket.
*/
func didAcceptNewSocket(_ socket: ProxySocket) {
observer?.signal(.newSocketAccepted(socket, onServer: self))
let tunnel = Tunnel(proxySocket: socket)
tunnel.delegate = self
tunnels.append(tunnel)
tunnel.openTunnel()
}
// MARK: TunnelDelegate implementation
/**
Delegate method when a tunnel closed. The server will remote it internally.
- parameter tunnel: The closed tunnel.
*/
func tunnelDidClose(_ tunnel: Tunnel) {
observer?.signal(.tunnelClosed(tunnel, onServer: self))
guard let index = tunnels.firstIndex(of: tunnel) else {
// things went strange
return
}
tunnels.remove(at: index)
}
}

View File

@@ -0,0 +1,253 @@
import Foundation
/// The TCP socket build upon `GCDAsyncSocket`.
///
/// - warning: This class is not thread-safe.
open class GCDTCPSocket: NSObject, GCDAsyncSocketDelegate, RawTCPSocketProtocol {
fileprivate let socket: GCDAsyncSocket
fileprivate var enableTLS: Bool = false
fileprivate var host: String?
/**
Initailize an instance with `GCDAsyncSocket`.
- parameter socket: The socket object to work with. If this is `nil`, then a new `GCDAsyncSocket` instance is created.
*/
public init(socket: GCDAsyncSocket? = nil) {
if let socket = socket {
self.socket = socket
self.socket.setDelegate(nil, delegateQueue: QueueFactory.getQueue())
} else {
self.socket = GCDAsyncSocket(delegate: nil, delegateQueue: QueueFactory.getQueue(), socketQueue: QueueFactory.getQueue())
}
super.init()
self.socket.synchronouslySetDelegate(self)
}
// MARK: RawTCPSocketProtocol implementation
/// The `RawTCPSocketDelegate` instance.
weak open var delegate: RawTCPSocketDelegate?
/// If the socket is connected.
open var isConnected: Bool {
return !socket.isDisconnected
}
/// The source address.
open var sourceIPAddress: IPAddress? {
guard let localHost = socket.localHost else {
return nil
}
return IPAddress(fromString: localHost)
}
/// The source port.
open var sourcePort: Port? {
return Port(port: socket.localPort)
}
/// The destination address.
///
/// - note: Always returns `nil`.
open var destinationIPAddress: IPAddress? {
return nil
}
/// The destination port.
///
/// - note: Always returns `nil`.
open var destinationPort: Port? {
return nil
}
/**
Connect to remote host.
- parameter host: Remote host.
- parameter port: Remote port.
- parameter enableTLS: Should TLS be enabled.
- parameter tlsSettings: The settings of TLS.
- throws: The error occured when connecting to host.
*/
open func connectTo(host: String, port: Int, enableTLS: Bool = false, tlsSettings: [AnyHashable: Any]? = nil) throws {
self.host = host
try connectTo(host: host, withPort: port)
self.enableTLS = enableTLS
if enableTLS {
startTLSWith(settings: tlsSettings)
}
}
/**
Disconnect the socket.
The socket will disconnect elegantly after any queued writing data are successfully sent.
*/
open func disconnect() {
socket.disconnectAfterWriting()
}
/**
Disconnect the socket immediately.
*/
open func forceDisconnect() {
socket.disconnect()
}
/**
Send data to remote.
- parameter data: Data to send.
- warning: This should only be called after the last write is finished, i.e., `delegate?.didWriteData()` is called.
*/
open func write(data: Data) {
write(data: data, withTimeout: -1)
}
/**
Read data from the socket.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
open func readData() {
socket.readData(withTimeout: -1, tag: 0)
}
/**
Read specific length of data from the socket.
- parameter length: The length of the data to read.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
open func readDataTo(length: Int) {
readDataTo(length: length, withTimeout: -1)
}
/**
Read data until a specific pattern (including the pattern).
- parameter data: The pattern.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
open func readDataTo(data: Data) {
readDataTo(data: data, maxLength: 0)
}
/**
Read data until a specific pattern (including the pattern).
- parameter data: The pattern.
- parameter maxLength: Ignored since `GCDAsyncSocket` does not support this. The max length of data to scan for the pattern.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
open func readDataTo(data: Data, maxLength: Int) {
readDataTo(data: data, withTimeout: -1)
}
// MARK: Other helper methods
/**
Send data to remote.
- parameter data: Data to send.
- parameter timeout: Operation timeout.
- warning: This should only be called after the last write is finished, i.e., `delegate?.didWriteData()` is called.
*/
func write(data: Data, withTimeout timeout: Double) {
guard data.count > 0 else {
QueueFactory.getQueue().async {
self.delegate?.didWrite(data: data, by: self)
}
return
}
socket.write(data, withTimeout: timeout, tag: 0)
}
/**
Read specific length of data from the socket.
- parameter length: The length of the data to read.
- parameter timeout: Operation timeout.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
func readDataTo(length: Int, withTimeout timeout: Double) {
socket.readData(toLength: UInt(length), withTimeout: timeout, tag: 0)
}
/**
Read data until a specific pattern (including the pattern).
- parameter data: The pattern.
- parameter timeout: Operation timeout.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
func readDataTo(data: Data, withTimeout timeout: Double) {
socket.readData(to: data, withTimeout: timeout, tag: 0)
}
/**
Connect to remote host.
- parameter host: Remote host.
- parameter port: Remote port.
- throws: The error occured when connecting to host.
*/
func connectTo(host: String, withPort port: Int) throws {
try socket.connect(toHost: host, onPort: UInt16(port))
}
/**
Secures the connection using SSL/TLS.
- parameter tlsSettings: TLS settings, refer to documents of `GCDAsyncSocket` for detail.
*/
func startTLSWith(settings: [AnyHashable: Any]!) {
if let settings = settings as? [String: NSObject] {
socket.startTLS(ensureSendPeerName(tlsSettings: settings))
} else {
socket.startTLS(ensureSendPeerName(tlsSettings: nil))
}
}
// MARK: Delegate methods for GCDAsyncSocket
open func socket(_ sock: GCDAsyncSocket, didWriteDataWithTag tag: Int) {
delegate?.didWrite(data: nil, by: self)
}
open func socket(_ sock: GCDAsyncSocket, didRead data: Data, withTag tag: Int) {
delegate?.didRead(data: data, from: self)
}
open func socketDidDisconnect(_ socket: GCDAsyncSocket, withError err: Error?) {
delegate?.didDisconnectWith(socket: self)
delegate = nil
socket.setDelegate(nil, delegateQueue: nil)
}
open func socket(_ sock: GCDAsyncSocket, didConnectToHost host: String, port: UInt16) {
if !enableTLS {
delegate?.didConnectWith(socket: self)
}
}
open func socketDidSecure(_ sock: GCDAsyncSocket) {
if enableTLS {
delegate?.didConnectWith(socket: self)
}
}
private func ensureSendPeerName(tlsSettings: [String: NSObject]? = nil) -> [String: NSObject] {
var setting = tlsSettings ?? [:]
guard setting[kCFStreamSSLPeerName as String] == nil else {
return setting
}
setting[kCFStreamSSLPeerName as String] = host! as NSString
return setting
}
}

View File

@@ -0,0 +1,329 @@
import Foundation
import NetworkExtension
/// The TCP socket build upon `NWTCPConnection`.
///
/// - warning: This class is not thread-safe.
public class NWTCPSocket: NSObject, RawTCPSocketProtocol {
private var connection: NWTCPConnection?
private var writePending = false
private var closeAfterWriting = false
private var cancelled = false
private var scanner: StreamScanner!
private var scanning: Bool = false
private var readDataPrefix: Data?
// MARK: RawTCPSocketProtocol implementation
/// The `RawTCPSocketDelegate` instance.
weak open var delegate: RawTCPSocketDelegate?
/// If the socket is connected.
public var isConnected: Bool {
return connection != nil && connection!.state == .connected
}
/// The source address.
///
/// - note: Always returns `nil`.
public var sourceIPAddress: IPAddress? {
return nil
}
/// The source port.
///
/// - note: Always returns `nil`.
public var sourcePort: Port? {
return nil
}
/// The destination address.
///
/// - note: Always returns `nil`.
public var destinationIPAddress: IPAddress? {
return nil
}
/// The destination port.
///
/// - note: Always returns `nil`.
public var destinationPort: Port? {
return nil
}
/**
Connect to remote host.
- parameter host: Remote host.
- parameter port: Remote port.
- parameter enableTLS: Should TLS be enabled.
- parameter tlsSettings: The settings of TLS.
- throws: Never throws.
*/
public func connectTo(host: String, port: Int, enableTLS: Bool, tlsSettings: [AnyHashable: Any]?) throws {
let endpoint = NWHostEndpoint(hostname: host, port: "\(port)")
let tlsParameters = NWTLSParameters()
if let tlsSettings = tlsSettings as? [String: AnyObject] {
tlsParameters.setValuesForKeys(tlsSettings)
}
guard let connection = RawSocketFactory.TunnelProvider?.createTCPConnection(to: endpoint, enableTLS: enableTLS, tlsParameters: tlsParameters, delegate: nil) else {
// This should only happen when the extension is already stopped and `RawSocketFactory.TunnelProvider` is set to `nil`.
return
}
self.connection = connection
connection.addObserver(self, forKeyPath: "state", options: [.initial, .new], context: nil)
}
/**
Disconnect the socket.
The socket will disconnect elegantly after any queued writing data are successfully sent.
*/
public func disconnect() {
cancelled = true
if connection == nil || connection!.state == .cancelled {
delegate?.didDisconnectWith(socket: self)
} else {
closeAfterWriting = true
checkStatus()
}
}
/**
Disconnect the socket immediately.
*/
public func forceDisconnect() {
cancelled = true
if connection == nil || connection!.state == .cancelled {
delegate?.didDisconnectWith(socket: self)
} else {
cancel()
}
}
/**
Send data to remote.
- parameter data: Data to send.
- warning: This should only be called after the last write is finished, i.e., `delegate?.didWriteData()` is called.
*/
public func write(data: Data) {
guard !cancelled else {
return
}
guard data.count > 0 else {
QueueFactory.getQueue().async {
self.delegate?.didWrite(data: data, by: self)
}
return
}
send(data: data)
}
/**
Read data from the socket.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
public func readData() {
guard !cancelled else {
return
}
connection!.readMinimumLength(1, maximumLength: Opt.MAXNWTCPSocketReadDataSize) { data, error in
guard error == nil else {
DDLogError("NWTCPSocket got an error when reading data: \(String(describing: error))")
self.queueCall {
self.disconnect()
}
return
}
self.readCallback(data: data)
}
}
/**
Read specific length of data from the socket.
- parameter length: The length of the data to read.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
public func readDataTo(length: Int) {
guard !cancelled else {
return
}
connection!.readLength(length) { data, error in
guard error == nil else {
DDLogError("NWTCPSocket got an error when reading data: \(String(describing: error))")
self.queueCall {
self.disconnect()
}
return
}
self.readCallback(data: data)
}
}
/**
Read data until a specific pattern (including the pattern).
- parameter data: The pattern.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
public func readDataTo(data: Data) {
readDataTo(data: data, maxLength: 0)
}
// Actually, this method is available as `- (void)readToPattern:(id)arg1 maximumLength:(unsigned int)arg2 completionHandler:(id /* block */)arg3;`
// which is sadly not available in public header for some reason I don't know.
// I don't want to do it myself since This method is not trival to implement and I don't like reinventing the wheel.
// Here is only the most naive version, which may not be the optimal if using with large data blocks.
/**
Read data until a specific pattern (including the pattern).
- parameter data: The pattern.
- parameter maxLength: The max length of data to scan for the pattern.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
public func readDataTo(data: Data, maxLength: Int) {
guard !cancelled else {
return
}
var maxLength = maxLength
if maxLength == 0 {
maxLength = Opt.MAXNWTCPScanLength
}
scanner = StreamScanner(pattern: data, maximumLength: maxLength)
scanning = true
readData()
}
private func queueCall(_ block: @escaping () -> Void) {
QueueFactory.getQueue().async(execute: block)
}
override public func observeValue(forKeyPath keyPath: String?, of object: Any?, change: [NSKeyValueChangeKey : Any]?, context: UnsafeMutableRawPointer?) {
guard keyPath == "state" else {
return
}
switch connection!.state {
case .connected:
queueCall {
self.delegate?.didConnectWith(socket: self)
}
case .disconnected:
cancelled = true
cancel()
case .cancelled:
cancelled = true
queueCall {
let delegate = self.delegate
self.delegate = nil
delegate?.didDisconnectWith(socket: self)
}
default:
break
}
}
private func readCallback(data: Data?) {
guard !cancelled else {
return
}
queueCall {
guard let data = self.consumeReadData(data) else {
// remote read is closed, but this is okay, nothing need to be done, if this socket is read again, then error occurs.
return
}
if self.scanning {
guard let (match, rest) = self.scanner.addAndScan(data) else {
self.readData()
return
}
self.scanner = nil
self.scanning = false
guard let matchData = match else {
// do not find match in the given length, stop now
return
}
self.readDataPrefix = rest
self.delegate?.didRead(data: matchData, from: self)
} else {
self.delegate?.didRead(data: data, from: self)
}
}
}
private func send(data: Data) {
writePending = true
self.connection!.write(data) { error in
self.queueCall {
self.writePending = false
guard error == nil else {
DDLogError("NWTCPSocket got an error when writing data: \(String(describing: error))")
self.disconnect()
return
}
self.delegate?.didWrite(data: data, by: self)
self.checkStatus()
}
}
}
private func consumeReadData(_ data: Data?) -> Data? {
defer {
readDataPrefix = nil
}
if readDataPrefix == nil {
return data
}
if data == nil {
return readDataPrefix
}
var wholeData = readDataPrefix!
wholeData.append(data!)
return wholeData
}
private func cancel() {
connection?.cancel()
}
private func checkStatus() {
if closeAfterWriting && !writePending {
cancel()
}
}
deinit {
guard let connection = connection else {
return
}
connection.removeObserver(self, forKeyPath: "state")
}
}

View File

@@ -0,0 +1,160 @@
import Foundation
import NetworkExtension
/// The delegate protocol of `NWUDPSocket`.
public protocol NWUDPSocketDelegate: class {
/**
Socket did receive data from remote.
- parameter data: The data.
- parameter from: The socket the data is read from.
*/
func didReceive(data: Data, from: NWUDPSocket)
func didCancel(socket: NWUDPSocket)
}
/// The wrapper for NWUDPSession.
///
/// - note: This class is thread-safe.
public class NWUDPSocket: NSObject {
private let session: NWUDPSession
private var pendingWriteData: [Data] = []
private var writing = false
private let queue: DispatchQueue = QueueFactory.getQueue()
private let timer: DispatchSourceTimer
private let timeout: Int
/// The delegate instance.
public weak var delegate: NWUDPSocketDelegate?
/// The time when the last activity happens.
///
/// Since UDP do not have a "close" semantic, this can be an indicator of timeout.
public var lastActive: Date = Date()
/**
Create a new UDP socket connecting to remote.
- parameter host: The host.
- parameter port: The port.
*/
public init?(host: String, port: Int, timeout: Int = Opt.UDPSocketActiveTimeout) {
guard let udpsession = RawSocketFactory.TunnelProvider?.createUDPSession(to: NWHostEndpoint(hostname: host, port: "\(port)"), from: nil) else {
return nil
}
session = udpsession
self.timeout = timeout
timer = DispatchSource.makeTimerSource(queue: queue)
super.init()
timer.schedule(deadline: DispatchTime.now(), repeating: DispatchTimeInterval.seconds(Opt.UDPSocketActiveCheckInterval), leeway: DispatchTimeInterval.seconds(Opt.UDPSocketActiveCheckInterval))
timer.setEventHandler { [weak self] in
self?.queueCall {
self?.checkStatus()
}
}
timer.resume()
session.addObserver(self, forKeyPath: #keyPath(NWUDPSession.state), options: [.new], context: nil)
session.setReadHandler({ [ weak self ] dataArray, error in
self?.queueCall {
guard let sSelf = self else {
return
}
sSelf.updateActivityTimer()
guard error == nil, let dataArray = dataArray else {
DDLogError("Error when reading from remote server. \(error?.localizedDescription ?? "Connection reset")")
return
}
for data in dataArray {
sSelf.delegate?.didReceive(data: data, from: sSelf)
}
}
}, maxDatagrams: 32)
}
/**
Send data to remote.
- parameter data: The data to send.
*/
public func write(data: Data) {
pendingWriteData.append(data)
checkWrite()
}
public func disconnect() {
session.cancel()
timer.cancel()
}
public override func observeValue(forKeyPath keyPath: String?, of object: Any?, change: [NSKeyValueChangeKey : Any]?, context: UnsafeMutableRawPointer?) {
guard keyPath == "state" else {
return
}
switch session.state {
case .cancelled:
queueCall {
self.delegate?.didCancel(socket: self)
}
case .ready:
checkWrite()
default:
break
}
}
private func checkWrite() {
updateActivityTimer()
guard session.state == .ready else {
return
}
guard !writing else {
return
}
guard pendingWriteData.count > 0 else {
return
}
writing = true
session.writeMultipleDatagrams(self.pendingWriteData) {_ in
self.queueCall {
self.writing = false
self.checkWrite()
}
}
self.pendingWriteData.removeAll(keepingCapacity: true)
}
private func updateActivityTimer() {
lastActive = Date()
}
private func checkStatus() {
if timeout > 0 && Date().timeIntervalSince(lastActive) > TimeInterval(timeout) {
disconnect()
}
}
private func queueCall(block: @escaping () -> Void) {
queue.async {
block()
}
}
deinit {
session.removeObserver(self, forKeyPath: #keyPath(NWUDPSession.state))
}
}

View File

@@ -0,0 +1,42 @@
import Foundation
import NetworkExtension
/**
Represents the type of the socket.
- NW: The socket based on `NWTCPConnection`.
- GCD: The socket based on `GCDAsyncSocket`.
*/
public enum SocketBaseType {
case nw, gcd
}
/// Factory to create `RawTCPSocket` based on configuration.
open class RawSocketFactory {
/// Current active `NETunnelProvider` which creates `NWTCPConnection` instance.
///
/// - note: Must set before any connection is created if `NWTCPSocket` or `NWUDPSocket` is used.
public static weak var TunnelProvider: NETunnelProvider?
/**
Return `RawTCPSocket` instance.
- parameter type: The type of the socket.
- returns: The created socket instance.
*/
public static func getRawSocket(_ type: SocketBaseType? = nil) -> RawTCPSocketProtocol {
switch type {
case .some(.nw):
return NWTCPSocket()
case .some(.gcd):
return GCDTCPSocket()
case nil:
if RawSocketFactory.TunnelProvider == nil {
return GCDTCPSocket()
} else {
return NWTCPSocket()
}
}
}
}

View File

@@ -0,0 +1,129 @@
import Foundation
/// The raw socket protocol which represents a TCP socket.
///
/// Any concrete implementation does not need to be thread-safe.
///
/// - warning: It is expected that the instance is accessed on the specific queue only.
public protocol RawTCPSocketProtocol : class {
/// The `RawTCPSocketDelegate` instance.
var delegate: RawTCPSocketDelegate? { get set }
/// If the socket is connected.
var isConnected: Bool { get }
/// The source address.
var sourceIPAddress: IPAddress? { get }
/// The source port.
var sourcePort: Port? { get }
/// The destination address.
var destinationIPAddress: IPAddress? { get }
/// The destination port.
var destinationPort: Port? { get }
/**
Connect to remote host.
- parameter host: Remote host.
- parameter port: Remote port.
- parameter enableTLS: Should TLS be enabled.
- parameter tlsSettings: The settings of TLS.
- throws: The error occured when connecting to host.
*/
func connectTo(host: String, port: Int, enableTLS: Bool, tlsSettings: [AnyHashable: Any]?) throws
/**
Disconnect the socket.
The socket should disconnect elegantly after any queued writing data are successfully sent.
- note: Usually, any concrete implementation should wait until any pending writing data are finished then call `forceDisconnect()`.
*/
func disconnect()
/**
Disconnect the socket immediately.
- note: The socket should disconnect as soon as possible.
*/
func forceDisconnect()
/**
Send data to remote.
- parameter data: Data to send.
- warning: This should only be called after the last write is finished, i.e., `delegate?.didWriteData()` is called.
*/
func write(data: Data)
/**
Read data from the socket.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
func readData()
/**
Read specific length of data from the socket.
- parameter length: The length of the data to read.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
func readDataTo(length: Int)
/**
Read data until a specific pattern (including the pattern).
- parameter data: The pattern.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
func readDataTo(data: Data)
/**
Read data until a specific pattern (including the pattern).
- parameter data: The pattern.
- parameter maxLength: The max length of data to scan for the pattern.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
func readDataTo(data: Data, maxLength: Int)
}
/// The delegate protocol to handle the events from a raw TCP socket.
public protocol RawTCPSocketDelegate: class {
/**
The socket did disconnect.
This should only be called once in the entire lifetime of a socket. After this is called, the delegate will not receive any other events from that socket and the socket should be released.
- parameter socket: The socket which did disconnect.
*/
func didDisconnectWith(socket: RawTCPSocketProtocol)
/**
The socket did read some data.
- parameter data: The data read from the socket.
- parameter from: The socket where the data is read from.
*/
func didRead(data: Data, from: RawTCPSocketProtocol)
/**
The socket did send some data.
- parameter data: The data which have been sent to remote (acknowledged). Note this may not be available since the data may be released to save memory.
- parameter by: The socket where the data is sent out.
*/
func didWrite(data: Data?, by: RawTCPSocketProtocol)
/**
The socket did connect to remote.
- parameter socket: The connected socket.
*/
func didConnectWith(socket: RawTCPSocketProtocol)
}

View File

@@ -0,0 +1,13 @@
import Foundation
open class ResponseGenerator {
public let session: ConnectSession
public init(withSession session: ConnectSession) {
self.session = session
}
open func generateResponse() -> Data {
return Data()
}
}

View File

@@ -0,0 +1,6 @@
import Foundation
open class ResponseGeneratorFactory {
static var HTTPProxyResponseGenerator: ResponseGenerator.Type?
static var SOCKS5ProxyResponseGenerator: ResponseGenerator.Type?
}

View File

@@ -0,0 +1,48 @@
import Foundation
/// The rule matches all DNS and connect sessions.
open class AllRule: Rule {
fileprivate let adapterFactory: AdapterFactory
open override var description: String {
return "<AllRule>"
}
/**
Create a new `AllRule` instance.
- parameter adapterFactory: The factory which builds a corresponding adapter when needed.
*/
public init(adapterFactory: AdapterFactory) {
self.adapterFactory = adapterFactory
super.init()
}
/**
Match DNS session to this rule.
- parameter session: The DNS session to match.
- parameter type: What kind of information is available.
- returns: The result of match.
*/
override open func matchDNS(_ session: DNSSession, type: DNSSessionMatchType) -> DNSSessionMatchResult {
// only return real IP when we connect to remote directly
if let _ = adapterFactory as? DirectAdapterFactory {
return .real
} else {
return .fake
}
}
/**
Match connect session to this rule.
- parameter session: connect session to match.
- returns: The configured adapter.
*/
override open func match(_ session: ConnectSession) -> AdapterFactory? {
return adapterFactory
}
}

View File

@@ -0,0 +1,60 @@
import Foundation
/// The rule matches the request which failed to look up.
open class DNSFailRule: Rule {
fileprivate let adapterFactory: AdapterFactory
open override var description: String {
return "<DNSFailRule>"
}
/**
Create a new `DNSFailRule` instance.
- parameter adapterFactory: The factory which builds a corresponding adapter when needed.
*/
public init(adapterFactory: AdapterFactory) {
self.adapterFactory = adapterFactory
super.init()
}
/**
Match DNS request to this rule.
- parameter session: The DNS session to match.
- parameter type: What kind of information is available.
- returns: The result of match.
*/
override open func matchDNS(_ session: DNSSession, type: DNSSessionMatchType) -> DNSSessionMatchResult {
guard type == .ip else {
return .unknown
}
// only return real IP when we connect to remote directly
if session.realIP == nil {
if let _ = adapterFactory as? DirectAdapterFactory {
return .real
} else {
return .fake
}
} else {
return .pass
}
}
/**
Match connect session to this rule.
- parameter session: connect session to match.
- returns: The configured adapter.
*/
override open func match(_ session: ConnectSession) -> AdapterFactory? {
if session.ipAddress == "" {
return adapterFactory
} else {
return nil
}
}
}

View File

@@ -0,0 +1,13 @@
import Foundation
/**
The result of matching the rule to DNS request.
- Real: The request matches the rule and the connection can be done with a real IP address.
- Fake: The request matches the rule but we need to identify this session when a later connection is fired with an IP address instead of the host domain.
- Unknown: The match type is `DNSSessionMatchType.Domain` but rule needs the resolved IP address.
- Pass: This rule does not match the request.
*/
public enum DNSSessionMatchResult {
case real, fake, unknown, pass
}

View File

@@ -0,0 +1,13 @@
import Foundation
/**
The information available in current round of matching.
Since we want to speed things up, we first match the request without resolving it (`.Domain`). If any rule returns `.Unknown`, we lookup the request and rematches that rule (`.IP`).
- Domain: Only domain information is available.
- IP: The IP address is resolved.
*/
public enum DNSSessionMatchType {
case domain, ip
}

View File

@@ -0,0 +1,16 @@
import Foundation
/// The rule matches every request and returns direct adapter.
///
/// This is equivalent to create an `AllRule` with a `DirectAdapterFactory`.
open class DirectRule: AllRule {
open override var description: String {
return "<DirectRule>"
}
/**
Create a new `DirectRule` instance.
*/
public init() {
super.init(adapterFactory: DirectAdapterFactory())
}
}

View File

@@ -0,0 +1,84 @@
import Foundation
/// The rule matches the host domain to a list of predefined criteria.
open class DomainListRule: Rule {
public enum MatchCriterion {
case regex(NSRegularExpression), prefix(String), suffix(String), keyword(String), complete(String)
func match(_ domain: String) -> Bool {
switch self {
case .regex(let regex):
return regex.firstMatch(in: domain, options: [], range: NSRange(location: 0, length: domain.utf8.count)) != nil
case .prefix(let prefix):
return domain.hasPrefix(prefix)
case .suffix(let suffix):
return domain.hasSuffix(suffix)
case .keyword(let keyword):
return domain.contains(keyword)
case .complete(let match):
return domain == match
}
}
}
fileprivate let adapterFactory: AdapterFactory
open override var description: String {
return "<DomainListRule>"
}
/// The list of criteria to match to.
open var matchCriteria: [MatchCriterion] = []
/**
Create a new `DomainListRule` instance.
- parameter adapterFactory: The factory which builds a corresponding adapter when needed.
- parameter criteria: The list of criteria to match.
*/
public init(adapterFactory: AdapterFactory, criteria: [MatchCriterion]) {
self.adapterFactory = adapterFactory
self.matchCriteria = criteria
}
/**
Match DNS request to this rule.
- parameter session: The DNS session to match.
- parameter type: What kind of information is available.
- returns: The result of match.
*/
override open func matchDNS(_ session: DNSSession, type: DNSSessionMatchType) -> DNSSessionMatchResult {
if matchDomain(session.requestMessage.queries.first!.name) {
if let _ = adapterFactory as? DirectAdapterFactory {
return .real
}
return .fake
}
return .pass
}
/**
Match connect session to this rule.
- parameter session: connect session to match.
- returns: The configured adapter if matched, return `nil` if not matched.
*/
override open func match(_ session: ConnectSession) -> AdapterFactory? {
if matchDomain(session.host) {
return adapterFactory
}
return nil
}
fileprivate func matchDomain(_ domain: String) -> Bool {
for criterion in matchCriteria {
if criterion.match(domain) {
return true
}
}
return false
}
}

View File

@@ -0,0 +1,75 @@
import Foundation
/// The rule matches the ip of the target hsot to a list of IP ranges.
open class IPRangeListRule: Rule {
fileprivate let adapterFactory: AdapterFactory
open override var description: String {
return "<IPRangeList>"
}
/// The list of regular expressions to match to.
open var ranges: [IPRange] = []
/**
Create a new `IPRangeListRule` instance.
- parameter adapterFactory: The factory which builds a corresponding adapter when needed.
- parameter ranges: The list of IP ranges to match. The IP ranges are expressed in CIDR form ("127.0.0.1/8") or range form ("127.0.0.1+16777216").
- throws: The error when parsing the IP range.
*/
public init(adapterFactory: AdapterFactory, ranges: [String]) throws {
self.adapterFactory = adapterFactory
self.ranges = try ranges.map {
let range = try IPRange(withString: $0)
return range
}
}
/**
Match DNS request to this rule.
- parameter session: The DNS session to match.
- parameter type: What kind of information is available.
- returns: The result of match.
*/
override open func matchDNS(_ session: DNSSession, type: DNSSessionMatchType) -> DNSSessionMatchResult {
guard type == .ip else {
return .unknown
}
// Probably we should match all answers?
guard let ip = session.realIP else {
return .pass
}
for range in ranges {
if range.contains(ip: ip) {
return .fake
}
}
return .pass
}
/**
Match connect session to this rule.
- parameter session: connect session to match.
- returns: The configured adapter if matched, return `nil` if not matched.
*/
override open func match(_ session: ConnectSession) -> AdapterFactory? {
guard let ip = IPAddress(fromString: session.ipAddress) else {
return nil
}
for range in ranges {
if range.contains(ip: ip) {
return adapterFactory
}
}
return nil
}
}

View File

@@ -0,0 +1,37 @@
import Foundation
/// The rule defines what to do for DNS requests and connect sessions.
open class Rule: CustomStringConvertible {
open var description: String {
return "<Rule>"
}
/**
Create a new rule.
*/
public init() {
}
/**
Match DNS request to this rule.
- parameter session: The DNS session to match.
- parameter type: What kind of information is available.
- returns: The result of match.
*/
open func matchDNS(_ session: DNSSession, type: DNSSessionMatchType) -> DNSSessionMatchResult {
return .real
}
/**
Match connect session to this rule.
- parameter session: connect session to match.
- returns: The configured adapter if matched, return `nil` if not matched.
*/
open func match(_ session: ConnectSession) -> AdapterFactory? {
return nil
}
}

View File

@@ -0,0 +1,80 @@
import Foundation
/// The class managing rules.
open class RuleManager {
/// The current used `RuleManager`, there is only one manager should be used at a time.
///
/// - note: This should be set before any DNS or connect sessions.
public static var currentManager: RuleManager = RuleManager(fromRules: [], appendDirect: true)
/// The rule list.
var rules: [Rule] = []
open var observer: Observer<RuleMatchEvent>?
/**
Create a new `RuleManager` from the given rules.
- parameter rules: The rules.
- parameter appendDirect: Whether to append a `DirectRule` at the end of the list so any request does not match with any rule go directly.
*/
public init(fromRules rules: [Rule], appendDirect: Bool = false) {
self.rules = []
if appendDirect || self.rules.count == 0 {
self.rules.append(DirectRule())
}
observer = ObserverFactory.currentFactory?.getObserverForRuleManager(self)
}
/**
Match DNS request to all rules.
- parameter session: The DNS session to match.
- parameter type: What kind of information is available.
*/
func matchDNS(_ session: DNSSession, type: DNSSessionMatchType) {
for (i, rule) in rules[session.indexToMatch..<rules.count].enumerated() {
let result = rule.matchDNS(session, type: type)
observer?.signal(.dnsRuleMatched(session, rule: rule, type: type, result: result))
switch result {
case .fake, .real, .unknown:
session.matchedRule = rule
session.matchResult = result
session.indexToMatch = i + session.indexToMatch // add the offset
return
case .pass:
break
}
}
}
/**
Match connect session to all rules.
- parameter session: connect session to match.
- returns: The matched configured adapter.
*/
func match(_ session: ConnectSession) -> AdapterFactory! {
if session.matchedRule != nil {
observer?.signal(.ruleMatched(session, rule: session.matchedRule!))
return session.matchedRule!.match(session)
}
for rule in rules {
if let adapterFactory = rule.match(session) {
observer?.signal(.ruleMatched(session, rule: rule))
session.matchedRule = rule
return adapterFactory
} else {
observer?.signal(.ruleDidNotMatch(session, rule: rule))
}
}
return nil // this should never happens
}
}

View File

@@ -0,0 +1,156 @@
import Foundation
open class AdapterSocket: NSObject, SocketProtocol, RawTCPSocketDelegate {
open var session: ConnectSession!
open var observer: Observer<AdapterSocketEvent>?
open override var description: String {
return "<\(typeName) host:\(session.host) port:\(session.port))>"
}
internal var _cancelled = false
public var isCancelled: Bool {
return _cancelled
}
/**
Connect to remote according to the `ConnectSession`.
- parameter session: The connect session.
*/
open func openSocketWith(session: ConnectSession) {
guard !isCancelled else {
return
}
self.session = session
observer?.signal(.socketOpened(self, withSession: session))
socket?.delegate = self
_status = .connecting
}
deinit {
socket?.delegate = nil
}
// MARK: SocketProtocol Implementation
/// The underlying TCP socket transmitting data.
open var socket: RawTCPSocketProtocol!
/// The delegate instance.
weak open var delegate: SocketDelegate?
var _status: SocketStatus = .invalid
/// The current connection status of the socket.
public var status: SocketStatus {
return _status
}
open var statusDescription: String {
return "\(status)"
}
public init(observe: Bool = true) {
super.init()
if observe {
observer = ObserverFactory.currentFactory?.getObserverForAdapterSocket(self)
}
}
/**
Read data from the socket.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
open func readData() {
guard !isCancelled else {
return
}
socket?.readData()
}
/**
Send data to remote.
- parameter data: Data to send.
- warning: This should only be called after the last write is finished, i.e., `delegate?.didWriteData()` is called.
*/
open func write(data: Data) {
guard !isCancelled else {
return
}
socket?.write(data: data)
}
/**
Disconnect the socket elegantly.
*/
open func disconnect(becauseOf error: Error? = nil) {
_status = .disconnecting
_cancelled = true
session.disconnected(becauseOf: error, by: .adapter)
observer?.signal(.disconnectCalled(self))
socket?.disconnect()
}
/**
Disconnect the socket immediately.
*/
open func forceDisconnect(becauseOf error: Error? = nil) {
_status = .disconnecting
_cancelled = true
session.disconnected(becauseOf: error, by: .adapter)
observer?.signal(.forceDisconnectCalled(self))
socket?.forceDisconnect()
}
// MARK: RawTCPSocketDelegate Protocol Implementation
/**
The socket did disconnect.
- parameter socket: The socket which did disconnect.
*/
open func didDisconnectWith(socket: RawTCPSocketProtocol) {
_status = .closed
observer?.signal(.disconnected(self))
delegate?.didDisconnectWith(socket: self)
}
/**
The socket did read some data.
- parameter data: The data read from the socket.
- parameter from: The socket where the data is read from.
*/
open func didRead(data: Data, from: RawTCPSocketProtocol) {
observer?.signal(.readData(data, on: self))
}
/**
The socket did send some data.
- parameter data: The data which have been sent to remote (acknowledged). Note this may not be available since the data may be released to save memory.
- parameter by: The socket where the data is sent out.
*/
open func didWrite(data: Data?, by: RawTCPSocketProtocol) {
observer?.signal(.wroteData(data, on: self))
}
/**
The socket did connect to remote.
- parameter socket: The connected socket.
*/
open func didConnectWith(socket: RawTCPSocketProtocol) {
_status = .established
observer?.signal(.connected(self))
delegate?.didConnectWith(adapterSocket: self)
}
}

View File

@@ -0,0 +1,48 @@
import Foundation
/// This adapter connects to remote directly.
public class DirectAdapter: AdapterSocket {
/// If this is set to `false`, then the IP address will be resolved by system.
var resolveHost = false
/**
Connect to remote according to the `ConnectSession`.
- parameter session: The connect session.
*/
override public func openSocketWith(session: ConnectSession) {
super.openSocketWith(session: session)
guard !isCancelled else {
return
}
do {
try socket.connectTo(host: session.host, port: Int(session.port), enableTLS: false, tlsSettings: nil)
} catch let error {
observer?.signal(.errorOccured(error, on: self))
disconnect()
}
}
/**
The socket did connect to remote.
- parameter socket: The connected socket.
*/
override public func didConnectWith(socket: RawTCPSocketProtocol) {
super.didConnectWith(socket: socket)
observer?.signal(.readyForForward(self))
delegate?.didBecomeReadyToForwardWith(socket: self)
}
override public func didRead(data: Data, from rawSocket: RawTCPSocketProtocol) {
super.didRead(data: data, from: rawSocket)
delegate?.didRead(data: data, from: self)
}
override public func didWrite(data: Data?, by rawSocket: RawTCPSocketProtocol) {
super.didWrite(data: data, by: rawSocket)
delegate?.didWrite(data: data, by: self)
}
}

View File

@@ -0,0 +1,35 @@
import Foundation
/// The base class of adapter factory.
open class AdapterFactory {
public init() {}
/**
Build an adapter.
- parameter session: The connect session.
- returns: The built adapter.
*/
open func getAdapterFor(session: ConnectSession) -> AdapterSocket {
return getDirectAdapter()
}
/**
Helper method to get a `DirectAdapter`.
- returns: A direct adapter.
*/
public func getDirectAdapter() -> AdapterSocket {
let adapter = DirectAdapter()
adapter.socket = RawSocketFactory.getRawSocket()
return adapter
}
}
/// Factory building direct adapters.
///
/// - note: This is needed since we need to identify direct adapter factory.
public class DirectAdapterFactory: AdapterFactory {
public override init() {}
}

View File

@@ -0,0 +1,27 @@
import Foundation
/// This is a very simple wrapper of a dict of type `[String: AdapterFactory]`.
///
/// Use it as a normal dict.
public class AdapterFactoryManager {
private var factoryDict: [String: AdapterFactory]
public subscript(index: String) -> AdapterFactory? {
get {
if index == "direct" {
return DirectAdapterFactory()
}
return factoryDict[index]
}
set { factoryDict[index] = newValue }
}
/**
Initialize a new factory manager.
- parameter factoryDict: The factory dict.
*/
public init(factoryDict: [String: AdapterFactory]) {
self.factoryDict = factoryDict
}
}

View File

@@ -0,0 +1,11 @@
import Foundation
/// Factory building server adapter which requires authentication.
open class HTTPAuthenticationAdapterFactory: ServerAdapterFactory {
let auth: HTTPAuthentication?
required public init(serverHost: String, serverPort: Int, auth: HTTPAuthentication?) {
self.auth = auth
super.init(serverHost: serverHost, serverPort: serverPort)
}
}

View File

@@ -0,0 +1,21 @@
import Foundation
/// Factory building HTTP adapter.
open class HTTPAdapterFactory: HTTPAuthenticationAdapterFactory {
required public init(serverHost: String, serverPort: Int, auth: HTTPAuthentication?) {
super.init(serverHost: serverHost, serverPort: serverPort, auth: auth)
}
/**
Get a HTTP adapter.
- parameter session: The connect session.
- returns: The built adapter.
*/
override open func getAdapterFor(session: ConnectSession) -> AdapterSocket {
let adapter = HTTPAdapter(serverHost: serverHost, serverPort: serverPort, auth: auth)
adapter.socket = RawSocketFactory.getRawSocket()
return adapter
}
}

View File

@@ -0,0 +1,13 @@
import Foundation
open class RejectAdapterFactory: AdapterFactory {
public let delay: Int
public init(delay: Int = Opt.RejectAdapterDefaultDelay) {
self.delay = delay
}
override open func getAdapterFor(session: ConnectSession) -> AdapterSocket {
return RejectAdapter(delay: delay)
}
}

View File

@@ -0,0 +1,21 @@
import Foundation
/// Factory building SOCKS5 adapter.
open class SOCKS5AdapterFactory: ServerAdapterFactory {
override public init(serverHost: String, serverPort: Int) {
super.init(serverHost: serverHost, serverPort: serverPort)
}
/**
Get a SOCKS5 adapter.
- parameter session: The connect session.
- returns: The built adapter.
*/
override open func getAdapterFor(session: ConnectSession) -> AdapterSocket {
let adapter = SOCKS5Adapter(serverHost: serverHost, serverPort: serverPort)
adapter.socket = RawSocketFactory.getRawSocket()
return adapter
}
}

View File

@@ -0,0 +1,21 @@
import Foundation
/// Factory building secured HTTP (HTTP with SSL) adapter.
open class SecureHTTPAdapterFactory: HTTPAdapterFactory {
required public init(serverHost: String, serverPort: Int, auth: HTTPAuthentication?) {
super.init(serverHost: serverHost, serverPort: serverPort, auth: auth)
}
/**
Get a secured HTTP adapter.
- parameter session: The connect session.
- returns: The built adapter.
*/
override open func getAdapterFor(session: ConnectSession) -> AdapterSocket {
let adapter = SecureHTTPAdapter(serverHost: serverHost, serverPort: serverPort, auth: auth)
adapter.socket = RawSocketFactory.getRawSocket()
return adapter
}
}

View File

@@ -0,0 +1,12 @@
import Foundation
/// Factory building adapter with proxy server host and port.
open class ServerAdapterFactory: AdapterFactory {
let serverHost: String
let serverPort: Int
public init(serverHost: String, serverPort: Int) {
self.serverHost = serverHost
self.serverPort = serverPort
}
}

View File

@@ -0,0 +1,28 @@
//import Foundation
//
///// Factory building Shadowsocks adapter.
//open class ShadowsocksAdapterFactory: ServerAdapterFactory {
// let protocolObfuscaterFactory: ShadowsocksAdapter.ProtocolObfuscater.Factory
// let cryptorFactory: ShadowsocksAdapter.CryptoStreamProcessor.Factory
// let streamObfuscaterFactory: ShadowsocksAdapter.StreamObfuscater.Factory
//
// public init(serverHost: String, serverPort: Int, protocolObfuscaterFactory: ShadowsocksAdapter.ProtocolObfuscater.Factory, cryptorFactory: ShadowsocksAdapter.CryptoStreamProcessor.Factory, streamObfuscaterFactory: ShadowsocksAdapter.StreamObfuscater.Factory) {
// self.protocolObfuscaterFactory = protocolObfuscaterFactory
// self.cryptorFactory = cryptorFactory
// self.streamObfuscaterFactory = streamObfuscaterFactory
// super.init(serverHost: serverHost, serverPort: serverPort)
// }
//
// /**
// Get a Shadowsocks adapter.
//
// - parameter session: The connect session.
//
// - returns: The built adapter.
// */
// override open func getAdapterFor(session: ConnectSession) -> AdapterSocket {
// let adapter = ShadowsocksAdapter(host: serverHost, port: serverPort, protocolObfuscater: protocolObfuscaterFactory.build(), cryptor: cryptorFactory.build(), streamObfuscator: streamObfuscaterFactory.build(for: session))
// adapter.socket = RawSocketFactory.getRawSocket()
// return adapter
// }
//}

View File

@@ -0,0 +1,26 @@
//import Foundation
//
///// Factory building speed adapter.
//open class SpeedAdapterFactory: AdapterFactory {
// open var adapterFactories: [(AdapterFactory, Int)]!
//
// public override init() {}
//
// /**
// Get a speed adapter.
//
// - parameter session: The connect session.
//
// - returns: The built adapter.
// */
// override open func getAdapterFor(session: ConnectSession) -> AdapterSocket {
// let adapters = adapterFactories.map { adapterFactory, delay -> (AdapterSocket, Int) in
// let adapter = adapterFactory.getAdapterFor(session: session)
// adapter.socket = RawSocketFactory.getRawSocket()
// return (adapter, delay)
// }
// let speedAdapter = SpeedAdapter()
// speedAdapter.adapters = adapters
// return speedAdapter
// }
//}

View File

@@ -0,0 +1,110 @@
import Foundation
public enum HTTPAdapterError: Error, CustomStringConvertible {
case invalidURL, serailizationFailure
public var description: String {
switch self {
case .invalidURL:
return "Invalid url when connecting through proxy"
case .serailizationFailure:
return "Failed to serialize HTTP CONNECT header"
}
}
}
/// This adapter connects to remote host through a HTTP proxy.
public class HTTPAdapter: AdapterSocket {
enum HTTPAdapterStatus {
case invalid,
connecting,
readingResponse,
forwarding,
stopped
}
/// The host domain of the HTTP proxy.
let serverHost: String
/// The port of the HTTP proxy.
let serverPort: Int
/// The authentication information for the HTTP proxy.
let auth: HTTPAuthentication?
/// Whether the connection to the proxy should be secured or not.
var secured: Bool
var internalStatus: HTTPAdapterStatus = .invalid
public init(serverHost: String, serverPort: Int, auth: HTTPAuthentication?) {
self.serverHost = serverHost
self.serverPort = serverPort
self.auth = auth
secured = false
super.init()
}
override public func openSocketWith(session: ConnectSession) {
super.openSocketWith(session: session)
guard !isCancelled else {
return
}
do {
internalStatus = .connecting
try socket.connectTo(host: serverHost, port: serverPort, enableTLS: secured, tlsSettings: nil)
} catch {}
}
override public func didConnectWith(socket: RawTCPSocketProtocol) {
super.didConnectWith(socket: socket)
guard let url = URL(string: "\(session.host):\(session.port)") else {
observer?.signal(.errorOccured(HTTPAdapterError.invalidURL, on: self))
disconnect()
return
}
let message = CFHTTPMessageCreateRequest(kCFAllocatorDefault, "CONNECT" as CFString, url as CFURL, kCFHTTPVersion1_1).takeRetainedValue()
if let authData = auth {
CFHTTPMessageSetHeaderFieldValue(message, "Proxy-Authorization" as CFString, authData.authString() as CFString?)
}
CFHTTPMessageSetHeaderFieldValue(message, "Host" as CFString, "\(session.host):\(session.port)" as CFString?)
CFHTTPMessageSetHeaderFieldValue(message, "Content-Length" as CFString, "0" as CFString?)
guard let requestData = CFHTTPMessageCopySerializedMessage(message)?.takeRetainedValue() else {
observer?.signal(.errorOccured(HTTPAdapterError.serailizationFailure, on: self))
disconnect()
return
}
internalStatus = .readingResponse
write(data: requestData as Data)
socket.readDataTo(data: Utils.HTTPData.DoubleCRLF)
}
override public func didRead(data: Data, from socket: RawTCPSocketProtocol) {
super.didRead(data: data, from: socket)
switch internalStatus {
case .readingResponse:
internalStatus = .forwarding
observer?.signal(.readyForForward(self))
delegate?.didBecomeReadyToForwardWith(socket: self)
case .forwarding:
observer?.signal(.readData(data, on: self))
delegate?.didRead(data: data, from: self)
default:
return
}
}
override public func didWrite(data: Data?, by socket: RawTCPSocketProtocol) {
super.didWrite(data: data, by: socket)
if internalStatus == .forwarding {
observer?.signal(.wroteData(data, on: self))
delegate?.didWrite(data: data, by: self)
}
}
}

View File

@@ -0,0 +1,49 @@
import Foundation
public class RejectAdapter: AdapterSocket {
public let delay: Int
public init(delay: Int) {
self.delay = delay
}
override public func openSocketWith(session: ConnectSession) {
super.openSocketWith(session: session)
QueueFactory.getQueue().asyncAfter(deadline: DispatchTime.now() + DispatchTimeInterval.milliseconds(delay)) {
[weak self] in
self?.disconnect()
}
}
/**
Disconnect the socket elegantly.
*/
public override func disconnect(becauseOf error: Error? = nil) {
guard !isCancelled else {
return
}
_cancelled = true
session.disconnected(becauseOf: error, by: .adapter)
observer?.signal(.disconnectCalled(self))
_status = .closed
delegate?.didDisconnectWith(socket: self)
}
/**
Disconnect the socket immediately.
*/
public override func forceDisconnect(becauseOf error: Error? = nil) {
guard !isCancelled else {
return
}
_cancelled = true
session.disconnected(becauseOf: error, by: .adapter)
observer?.signal(.forceDisconnectCalled(self))
_status = .closed
delegate?.didDisconnectWith(socket: self)
}
}

View File

@@ -0,0 +1,112 @@
import Foundation
public class SOCKS5Adapter: AdapterSocket {
enum SOCKS5AdapterStatus {
case invalid,
connecting,
readingMethodResponse,
readingResponseFirstPart,
readingResponseSecondPart,
forwarding
}
public let serverHost: String
public let serverPort: Int
var internalStatus: SOCKS5AdapterStatus = .invalid
let helloData = Data(bytes: UnsafePointer<UInt8>(([0x05, 0x01, 0x00] as [UInt8])), count: 3)
public enum ReadTag: Int {
case methodResponse = -20000, connectResponseFirstPart, connectResponseSecondPart
}
public enum WriteTag: Int {
case open = -21000, connectIPv4, connectIPv6, connectDomainLength, connectPort
}
public init(serverHost: String, serverPort: Int) {
self.serverHost = serverHost
self.serverPort = serverPort
super.init()
}
public override func openSocketWith(session: ConnectSession) {
super.openSocketWith(session: session)
guard !isCancelled else {
return
}
do {
internalStatus = .connecting
try socket.connectTo(host: serverHost, port: serverPort, enableTLS: false, tlsSettings: nil)
} catch {}
}
public override func didConnectWith(socket: RawTCPSocketProtocol) {
super.didConnectWith(socket: socket)
write(data: helloData)
internalStatus = .readingMethodResponse
socket.readDataTo(length: 2)
}
public override func didRead(data: Data, from socket: RawTCPSocketProtocol) {
super.didRead(data: data, from: socket)
switch internalStatus {
case .readingMethodResponse:
var response: [UInt8]
if session.isIPv4() {
response = [0x05, 0x01, 0x00, 0x01]
let address = IPAddress(fromString: session.host)!
response += [UInt8](address.dataInNetworkOrder)
} else if session.isIPv6() {
response = [0x05, 0x01, 0x00, 0x04]
let address = IPAddress(fromString: session.host)!
response += [UInt8](address.dataInNetworkOrder)
} else {
response = [0x05, 0x01, 0x00, 0x03]
response.append(UInt8(session.host.utf8.count))
response += [UInt8](session.host.utf8)
}
let portBytes: [UInt8] = Utils.toByteArray(UInt16(session.port)).reversed()
response.append(contentsOf: portBytes)
write(data: Data(response))
internalStatus = .readingResponseFirstPart
socket.readDataTo(length: 5)
case .readingResponseFirstPart:
var readLength = 0
switch data[3] {
case 1:
readLength = 3 + 2
case 3:
readLength = Int(data[4]) + 2
case 4:
readLength = 15 + 2
default:
break
}
internalStatus = .readingResponseSecondPart
socket.readDataTo(length: readLength)
case .readingResponseSecondPart:
internalStatus = .forwarding
observer?.signal(.readyForForward(self))
delegate?.didBecomeReadyToForwardWith(socket: self)
case .forwarding:
delegate?.didRead(data: data, from: self)
default:
return
}
}
override open func didWrite(data: Data?, by socket: RawTCPSocketProtocol) {
super.didWrite(data: data, by: socket)
if internalStatus == .forwarding {
delegate?.didWrite(data: data, by: self)
}
}
}

View File

@@ -0,0 +1,9 @@
import Foundation
/// This adapter connects to remote host through a HTTP proxy with SSL.
public class SecureHTTPAdapter: HTTPAdapter {
override public init(serverHost: String, serverPort: Int, auth: HTTPAuthentication?) {
super.init(serverHost: serverHost, serverPort: serverPort, auth: auth)
secured = true
}
}

View File

@@ -0,0 +1,133 @@
//import Foundation
//
//extension ShadowsocksAdapter {
// public class CryptoStreamProcessor {
// public class Factory {
// let password: String
// let algorithm: CryptoAlgorithm
// let key: Data
//
// public init(password: String, algorithm: CryptoAlgorithm) {
// self.password = password
// self.algorithm = algorithm
// key = CryptoHelper.getKey(password, methodType: algorithm)
// }
//
// public func build() -> CryptoStreamProcessor {
// return CryptoStreamProcessor(key: key, algorithm: algorithm)
// }
// }
//
// public weak var inputStreamProcessor: StreamObfuscater.StreamObfuscaterBase!
// public weak var outputStreamProcessor: ProtocolObfuscater.ProtocolObfuscaterBase!
//
// var readIV: Data!
// let key: Data
// let algorithm: CryptoAlgorithm
//
// var sendKey = false
//
// var buffer = Buffer(capacity: 0)
//
// lazy var writeIV: Data = {
// [unowned self] in
// CryptoHelper.getIV(self.algorithm)
// }()
// lazy var ivLength: Int = {
// [unowned self] in
// CryptoHelper.getIVLength(self.algorithm)
// }()
// lazy var encryptor: StreamCryptoProtocol = {
// [unowned self] in
// self.getCrypto(.encrypt)
// }()
// lazy var decryptor: StreamCryptoProtocol = {
// [unowned self] in
// self.getCrypto(.decrypt)
// }()
//
// init(key: Data, algorithm: CryptoAlgorithm) {
// self.key = key
// self.algorithm = algorithm
// }
//
// func encrypt(data: inout Data) {
// return encryptor.update(&data)
// }
//
// func decrypt(data: inout Data) {
// return decryptor.update(&data)
// }
//
// public func input(data: Data) throws {
// var data = data
//
// if readIV == nil {
// buffer.append(data: data)
// readIV = buffer.get(length: ivLength)
// guard readIV != nil else {
// try inputStreamProcessor!.input(data: Data())
// return
// }
//
// data = buffer.get() ?? Data()
// buffer.release()
// }
//
// decrypt(data: &data)
// try inputStreamProcessor!.input(data: data)
// }
//
// public func output(data: Data) {
// var data = data
// encrypt(data: &data)
// if sendKey {
// return outputStreamProcessor!.output(data: data)
// } else {
// sendKey = true
// var out = Data(capacity: data.count + writeIV.count)
// out.append(writeIV)
// out.append(data)
//
// return outputStreamProcessor!.output(data: out)
// }
// }
//
// private func getCrypto(_ operation: CryptoOperation) -> StreamCryptoProtocol {
// switch algorithm {
// case .AES128CFB, .AES192CFB, .AES256CFB:
// switch operation {
// case .decrypt:
// return CCCrypto(operation: .decrypt, mode: .cfb, algorithm: .aes, initialVector: readIV, key: key)
// case .encrypt:
// return CCCrypto(operation: .encrypt, mode: .cfb, algorithm: .aes, initialVector: writeIV, key: key)
// }
// case .CHACHA20:
// switch operation {
// case .decrypt:
// return SodiumStreamCrypto(key: key, iv: readIV, algorithm: .chacha20)
// case .encrypt:
// return SodiumStreamCrypto(key: key, iv: writeIV, algorithm: .chacha20)
// }
// case .SALSA20:
// switch operation {
// case .decrypt:
// return SodiumStreamCrypto(key: key, iv: readIV, algorithm: .salsa20)
// case .encrypt:
// return SodiumStreamCrypto(key: key, iv: writeIV, algorithm: .salsa20)
// }
// case .RC4MD5:
// var combinedKey = Data(capacity: key.count + ivLength)
// combinedKey.append(key)
// switch operation {
// case .decrypt:
// combinedKey.append(readIV)
// return CCCrypto(operation: .decrypt, mode: .rc4, algorithm: .rc4, initialVector: nil, key: MD5Hash.final(combinedKey))
// case .encrypt:
// combinedKey.append(writeIV)
// return CCCrypto(operation: .encrypt, mode: .rc4, algorithm: .rc4, initialVector: nil, key: MD5Hash.final(combinedKey))
// }
// }
// }
// }
//}

View File

@@ -0,0 +1,371 @@
//import Foundation
//
//extension ShadowsocksAdapter {
// public struct ProtocolObfuscater {
// public class Factory {
// public init() {}
//
// public func build() -> ProtocolObfuscaterBase {
// return ProtocolObfuscaterBase()
// }
// }
//
// public class ProtocolObfuscaterBase {
// public weak var inputStreamProcessor: CryptoStreamProcessor!
// public weak var outputStreamProcessor: ShadowsocksAdapter!
//
// public func start() {}
// public func input(data: Data) throws {}
// public func output(data: Data) {}
//
// public func didWrite() {}
// }
//
// public class OriginProtocolObfuscater: ProtocolObfuscaterBase {
//
// public class Factory: ProtocolObfuscater.Factory {
// public override init() {}
//
// public override func build() -> ShadowsocksAdapter.ProtocolObfuscater.ProtocolObfuscaterBase {
// return OriginProtocolObfuscater()
// }
// }
//
// public override func start() {
// outputStreamProcessor.becomeReadyToForward()
// }
//
// public override func input(data: Data) throws {
// try inputStreamProcessor.input(data: data)
// }
//
// public override func output(data: Data) {
// outputStreamProcessor.output(data: data)
// }
// }
//
// public class HTTPProtocolObfuscater: ProtocolObfuscaterBase {
//
// public class Factory: ProtocolObfuscater.Factory {
// let method: String
// let hosts: [String]
// let customHeader: String?
//
// public init(method: String = "GET", hosts: [String], customHeader: String?) {
// self.method = method
// self.hosts = hosts
// self.customHeader = customHeader
// }
//
// public override func build() -> ShadowsocksAdapter.ProtocolObfuscater.ProtocolObfuscaterBase {
// return HTTPProtocolObfuscater(method: method, hosts: hosts, customHeader: customHeader)
// }
// }
//
// static let headerLength = 30
// static let userAgent = ["Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0",
// "Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0",
// "Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36",
// "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/535.11 (KHTML, like Gecko) Ubuntu/11.10 Chromium/27.0.1453.93 Chrome/27.0.1453.93 Safari/537.36",
// "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:35.0) Gecko/20100101 Firefox/35.0",
// "Mozilla/5.0 (compatible; WOW64; MSIE 10.0; Windows NT 6.2)",
// "Mozilla/5.0 (Windows; U; Windows NT 6.1; en-US) AppleWebKit/533.20.25 (KHTML, like Gecko) Version/5.0.4 Safari/533.20.27",
// "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.3; Trident/7.0; .NET4.0E; .NET4.0C)",
// "Mozilla/5.0 (Windows NT 6.3; Trident/7.0; rv:11.0) like Gecko",
// "Mozilla/5.0 (Linux; Android 4.4; Nexus 5 Build/BuildID) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/30.0.0.0 Mobile Safari/537.36",
// "Mozilla/5.0 (iPad; CPU OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3",
// "Mozilla/5.0 (iPhone; CPU iPhone OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3"]
//
// let method: String
// let hosts: [String]
// let customHeader: String?
//
// var readingFakeHeader = false
// var sendHeader = false
// var remaining = false
//
// var buffer = Buffer(capacity: 8192)
//
// public init(method: String = "GET", hosts: [String], customHeader: String?) {
// self.method = method
// self.hosts = hosts
// self.customHeader = customHeader
// }
//
// private func generateHeader(encapsulating data: Data) -> String {
// let ind = Int(arc4random_uniform(UInt32(hosts.count)))
// let host = outputStreamProcessor.port == 80 ? hosts[ind] : "\(hosts[ind]):\(outputStreamProcessor.port)"
// var header = "\(method) /\(hexlify(data: data)) HTTP/1.1\r\nHost: \(host)\r\n"
// if let customHeader = customHeader {
// header += customHeader
// } else {
// let ind = Int(arc4random_uniform(UInt32(HTTPProtocolObfuscater.userAgent.count)))
// header += "User-Agent: \(HTTPProtocolObfuscater.userAgent[ind])\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nDNT: 1\r\nConnection: keep-alive"
// }
// header += "\r\n\r\n"
// return header
// }
//
// private func hexlify(data: Data) -> String {
// var result = ""
// for i in data {
// result = result.appendingFormat("%%%02x", i)
// }
// return result
// }
//
// public override func start() {
// readingFakeHeader = true
// outputStreamProcessor.becomeReadyToForward()
// }
//
// public override func input(data: Data) throws {
// if readingFakeHeader {
// buffer.append(data: data)
// if buffer.get(to: Utils.HTTPData.DoubleCRLF) != nil {
// readingFakeHeader = false
// if let remainData = buffer.get() {
// try inputStreamProcessor.input(data: remainData)
// return
// }
// }
// try inputStreamProcessor.input(data: Data())
// return
// }
//
// try inputStreamProcessor.input(data: data)
// }
//
// public override func output(data: Data) {
// if sendHeader {
// outputStreamProcessor.output(data: data)
// } else {
// var fakeRequestDataLength = inputStreamProcessor.key.count + HTTPProtocolObfuscater.headerLength
// if data.count - fakeRequestDataLength > 64 {
// fakeRequestDataLength += Int(arc4random_uniform(64))
// } else {
// fakeRequestDataLength = data.count
// }
//
// var outputData = generateHeader(encapsulating: data.subdata(in: 0 ..< fakeRequestDataLength)).data(using: .utf8)!
// outputData.append(data.subdata(in: fakeRequestDataLength ..< data.count))
// sendHeader = true
// outputStreamProcessor.output(data: outputData)
// }
// }
// }
//
// public class TLSProtocolObfuscater: ProtocolObfuscaterBase {
//
// public class Factory: ProtocolObfuscater.Factory {
// let hosts: [String]
//
// public init(hosts: [String]) {
// self.hosts = hosts
// }
//
// public override func build() -> ShadowsocksAdapter.ProtocolObfuscater.ProtocolObfuscaterBase {
// return TLSProtocolObfuscater(hosts: hosts)
// }
// }
//
// let hosts: [String]
// let clientID: Data = {
// var id = Data(count: 32)
// Utils.Random.fill(data: &id)
// return id
// }()
//
// private var status = 0
//
// private var buffer = Buffer(capacity: 1024)
//
// init(hosts: [String]) {
// self.hosts = hosts
// }
//
// public override func start() {
// handleStatus0()
// outputStreamProcessor.socket.readDataTo(length: 129)
// }
//
// public override func input(data: Data) throws {
// switch status {
// case 8:
// try handleInput(data: data)
// case 1:
// outputStreamProcessor.becomeReadyToForward()
// default:
// break
// }
// }
//
// public override func output(data: Data) {
// switch status {
// case 8:
// handleStatus8(data: data)
// return
// case 1:
// handleStatus1(data: data)
// return
// default:
// break
// }
// }
//
// private func authData() -> Data {
// var time = UInt32(Date.init().timeIntervalSince1970).bigEndian
// var output = Data(count: 32)
// var key = inputStreamProcessor.key
// key.append(clientID)
//
// withUnsafeBytes(of: &time) {
// output.replaceSubrange(0 ..< 4, with: $0)
// }
//
// Utils.Random.fill(data: &output, from: 4, length: 18)
// output.withUnsafeBytes {
// output.replaceSubrange(22 ..< 32, with: HMAC.final(value: $0.baseAddress!, length: 22, algorithm: .SHA1, key: key).subdata(in: 0..<10))
// }
// return output
// }
//
// private func pack(data: Data) -> Data {
// var output = Data()
// var left = data.count
// while left > 0 {
// let blockSize = UInt16(min(Int(arc4random_uniform(UInt32(UInt16.max))) % 4096 + 100, left))
// var blockSizeBE = blockSize.bigEndian
// output.append(contentsOf: [0x17, 0x03, 0x03])
// withUnsafeBytes(of: &blockSizeBE) {
// output.append($0.baseAddress!.assumingMemoryBound(to: UInt8.self), count: $0.count)
// }
// output.append(data.subdata(in: data.count - left ..< data.count - left + Int(blockSize)))
// left -= Int(blockSize)
// }
// return output
// }
//
// private func handleStatus8(data: Data) {
// outputStreamProcessor.output(data: pack(data: data))
// }
//
// private func handleStatus0() {
// status = 1
//
// var outData = Data()
// outData.append(contentsOf: [0x03, 0x03])
// outData.append(authData())
// outData.append(0x20)
// outData.append(clientID)
// outData.append(contentsOf: [0x00, 0x1c, 0xc0, 0x2b, 0xc0, 0x2f, 0xcc, 0xa9, 0xcc, 0xa8, 0xcc, 0x14, 0xcc, 0x13, 0xc0, 0x0a, 0xc0, 0x14, 0xc0, 0x09, 0xc0, 0x13, 0x00, 0x9c, 0x00, 0x35, 0x00, 0x2f, 0x00, 0x0a])
// outData.append("0100".data(using: .utf8)!)
//
// var extData = Data()
// extData.append(contentsOf: [0xff, 0x01, 0x00, 0x01, 0x00])
// let hostData = hosts[Int(arc4random_uniform(UInt32(hosts.count)))].data(using: .utf8)!
//
// var sniData = Data(capacity: hosts.count + 2 + 1 + 2 + 2 + 2)
//
// sniData.append(contentsOf: [0x00, 0x00])
//
// var _lenBE = UInt16(hostData.count + 5).bigEndian
// withUnsafeBytes(of: &_lenBE) {
// sniData.append($0.baseAddress!.assumingMemoryBound(to: UInt8.self), count: $0.count)
// }
//
// _lenBE = UInt16(hostData.count + 3).bigEndian
// withUnsafeBytes(of: &_lenBE) {
// sniData.append($0.baseAddress!.assumingMemoryBound(to: UInt8.self), count: $0.count)
// }
//
// sniData.append(0x00)
//
// _lenBE = UInt16(hostData.count).bigEndian
// withUnsafeBytes(of: &_lenBE) {
// sniData.append($0.baseAddress!.assumingMemoryBound(to: UInt8.self), count: $0.count)
// }
//
// sniData.append(hostData)
//
// extData.append(sniData)
//
// extData.append(contentsOf: [0x00, 0x17, 0x00, 0x00, 0x00, 0x23, 0x00, 0xd0])
//
// var randomData = Data(count: 208)
// Utils.Random.fill(data: &randomData)
// extData.append(randomData)
//
// extData.append(contentsOf: [0x00, 0x0d, 0x00, 0x16, 0x00, 0x14, 0x06, 0x01, 0x06, 0x03, 0x05, 0x01, 0x05, 0x03, 0x04, 0x01, 0x04, 0x03, 0x03, 0x01, 0x03, 0x03, 0x02, 0x01, 0x02, 0x03])
// extData.append(contentsOf: [0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00])
// extData.append(contentsOf: [0x00, 0x12, 0x00, 0x00])
// extData.append(contentsOf: [0x75, 0x50, 0x00, 0x00])
// extData.append(contentsOf: [0x00, 0x0b, 0x00, 0x02, 0x01, 0x00])
// extData.append(contentsOf: [0x00, 0x0a, 0x00, 0x06, 0x00, 0x04, 0x00, 0x17, 0x00, 0x18])
//
// _lenBE = UInt16(extData.count).bigEndian
// withUnsafeBytes(of: &_lenBE) {
// outData.append($0.baseAddress!.assumingMemoryBound(to: UInt8.self), count: $0.count)
// }
// outData.append(extData)
//
// var outputData = Data(capacity: outData.count + 9)
// outputData.append(contentsOf: [0x16, 0x03, 0x01])
// _lenBE = UInt16(outData.count + 4).bigEndian
// withUnsafeBytes(of: &_lenBE) {
// outputData.append($0.baseAddress!.assumingMemoryBound(to: UInt8.self), count: $0.count)
// }
// outputData.append(contentsOf: [0x01, 0x00])
// _lenBE = UInt16(outData.count).bigEndian
// withUnsafeBytes(of: &_lenBE) {
// outputData.append($0.baseAddress!.assumingMemoryBound(to: UInt8.self), count: $0.count)
// }
// outputData.append(outData)
// outputStreamProcessor.output(data: outputData)
// }
//
// private func handleStatus1(data: Data) {
// status = 8
//
// var outputData = Data()
// outputData.append(contentsOf: [0x14, 0x03, 0x03, 0x00, 0x01, 0x01, 0x16, 0x03, 0x03, 0x00, 0x20])
// var random = Data(count: 22)
// Utils.Random.fill(data: &random)
// outputData.append(random)
//
// var key = inputStreamProcessor.key
// key.append(clientID)
// outputData.withUnsafeBytes {
// outputData.append(HMAC.final(value: $0.baseAddress!, length: outputData.count, algorithm: .SHA1, key: key).subdata(in: 0..<10))
// }
//
// outputData.append(pack(data: data))
//
// outputStreamProcessor.output(data: outputData)
// }
//
// private func handleInput(data: Data) throws {
// buffer.append(data: data)
// var unpackedData = Data()
// while buffer.left > 5 {
// buffer.skip(3)
// var length: Int = 0
// buffer.withUnsafeBytes { (ptr: UnsafePointer<UInt16>) in
// length = Int(ptr.pointee.byteSwapped)
// }
// buffer.skip(2)
// if buffer.left >= length {
// unpackedData.append(buffer.get(length: length)!)
// continue
// } else {
// buffer.setBack(length: 5)
// break
// }
// }
// buffer.squeeze()
// try inputStreamProcessor.input(data: unpackedData)
// }
// }
//
// }
//}

View File

@@ -0,0 +1,112 @@
//import Foundation
//import CommonCrypto
//
///// This adapter connects to remote through Shadowsocks proxy.
//public class ShadowsocksAdapter: AdapterSocket {
// enum ShadowsocksAdapterStatus {
// case invalid,
// connecting,
// connected,
// forwarding,
// stopped
// }
//
// enum EncryptMethod: String {
// case AES128CFB = "AES-128-CFB", AES192CFB = "AES-192-CFB", AES256CFB = "AES-256-CFB"
//
// static let allValues: [EncryptMethod] = [.AES128CFB, .AES192CFB, .AES256CFB]
// }
//
// public let host: String
// public let port: Int
//
// var internalStatus: ShadowsocksAdapterStatus = .invalid
//
// private let protocolObfuscater: ProtocolObfuscater.ProtocolObfuscaterBase
// private let cryptor: CryptoStreamProcessor
// private let streamObfuscator: StreamObfuscater.StreamObfuscaterBase
//
// public init(host: String, port: Int, protocolObfuscater: ProtocolObfuscater.ProtocolObfuscaterBase, cryptor: CryptoStreamProcessor, streamObfuscator: StreamObfuscater.StreamObfuscaterBase) {
// self.host = host
// self.port = port
// self.protocolObfuscater = protocolObfuscater
// self.cryptor = cryptor
// self.streamObfuscator = streamObfuscator
//
// super.init()
//
// protocolObfuscater.inputStreamProcessor = cryptor
// protocolObfuscater.outputStreamProcessor = self
//
// cryptor.inputStreamProcessor = streamObfuscator
// cryptor.outputStreamProcessor = protocolObfuscater
//
// streamObfuscator.inputStreamProcessor = self
// streamObfuscator.outputStreamProcessor = cryptor
// }
//
// override public func openSocketWith(session: ConnectSession) {
// super.openSocketWith(session: session)
//
// do {
// internalStatus = .connecting
// try socket.connectTo(host: host, port: port, enableTLS: false, tlsSettings: nil)
// } catch let error {
// observer?.signal(.errorOccured(error, on: self))
// disconnect()
// }
// }
//
// override public func didConnectWith(socket: RawTCPSocketProtocol) {
// super.didConnectWith(socket: socket)
//
// internalStatus = .connected
//
// protocolObfuscater.start()
// }
//
// override public func didRead(data: Data, from socket: RawTCPSocketProtocol) {
// super.didRead(data: data, from: socket)
//
// do {
// try protocolObfuscater.input(data: data)
// } catch {
// disconnect()
// }
// }
//
// public override func write(data: Data) {
// streamObfuscator.output(data: data)
// }
//
// public func write(rawData: Data) {
// super.write(data: rawData)
// }
//
// public func input(data: Data) {
// delegate?.didRead(data: data, from: self)
// }
//
// public func output(data: Data) {
// write(rawData: data)
// }
//
// override public func didWrite(data: Data?, by socket: RawTCPSocketProtocol) {
// super.didWrite(data: data, by: socket)
//
// protocolObfuscater.didWrite()
//
// switch internalStatus {
// case .forwarding:
// delegate?.didWrite(data: data, by: self)
// default:
// return
// }
// }
//
// func becomeReadyToForward() {
// internalStatus = .forwarding
// observer?.signal(.readyForForward(self))
// delegate?.didBecomeReadyToForwardWith(socket: self)
// }
//}

View File

@@ -0,0 +1,167 @@
//import Foundation
//
//extension ShadowsocksAdapter {
// public struct StreamObfuscater {
// public class Factory {
// public init() {}
//
// public func build(for session: ConnectSession) -> StreamObfuscaterBase {
// return StreamObfuscaterBase(for: session)
// }
// }
//
// public class StreamObfuscaterBase {
// public weak var inputStreamProcessor: ShadowsocksAdapter!
// private weak var _outputStreamProcessor: CryptoStreamProcessor!
// public var outputStreamProcessor: CryptoStreamProcessor! {
// get {
// return _outputStreamProcessor
// }
// set {
// _outputStreamProcessor = newValue
// key = _outputStreamProcessor?.key
// writeIV = _outputStreamProcessor?.writeIV
// }
// }
//
// public var key: Data?
// public var writeIV: Data?
//
// let session: ConnectSession
//
// init(for session: ConnectSession) {
// self.session = session
// }
//
// func output(data: Data) {}
// func input(data: Data) throws {}
// }
//
// public class OriginStreamObfuscater: StreamObfuscaterBase {
// public class Factory: StreamObfuscater.Factory {
// public override init() {}
//
// public override func build(for session: ConnectSession) -> ShadowsocksAdapter.StreamObfuscater.StreamObfuscaterBase {
// return OriginStreamObfuscater(for: session)
// }
// }
//
// private var requestSend = false
//
// private func requestData(withData data: Data) -> Data {
// let hostLength = session.host.utf8.count
// let length = 1 + 1 + hostLength + 2 + data.count
// var response = Data(count: length)
// response[0] = 3
// response[1] = UInt8(hostLength)
// response.replaceSubrange(2..<2+hostLength, with: session.host.utf8)
// var beport = UInt16(session.port).bigEndian
// withUnsafeBytes(of: &beport) {
// response.replaceSubrange(2+hostLength..<4+hostLength, with: $0)
// }
// response.replaceSubrange(4+hostLength..<length, with: data)
// return response
// }
//
// public override func input(data: Data) throws {
// inputStreamProcessor!.input(data: data)
// }
//
// public override func output(data: Data) {
// if requestSend {
// return outputStreamProcessor!.output(data: data)
// } else {
// requestSend = true
// return outputStreamProcessor!.output(data: requestData(withData: data))
// }
// }
// }
//
// public class OTAStreamObfuscater: StreamObfuscaterBase {
// public class Factory: StreamObfuscater.Factory {
// public override init() {}
//
// public override func build(for session: ConnectSession) -> ShadowsocksAdapter.StreamObfuscater.StreamObfuscaterBase {
// return OTAStreamObfuscater(for: session)
// }
// }
//
// private var count: UInt32 = 0
//
// private let DATA_BLOCK_SIZE = 0xFFFF - 12
//
// private var requestSend = false
//
// private func requestData() -> Data {
// var response: [UInt8] = [0x13]
// response.append(UInt8(session.host.utf8.count))
// response += [UInt8](session.host.utf8)
// response += [UInt8](Utils.toByteArray(UInt16(session.port)).reversed())
// var responseData = Data(bytes: UnsafePointer<UInt8>(response), count: response.count)
// var keyiv = Data(count: key!.count + writeIV!.count)
//
// keyiv.replaceSubrange(0..<writeIV!.count, with: writeIV!)
// keyiv.replaceSubrange(writeIV!.count..<writeIV!.count + key!.count, with: key!)
// responseData.append(HMAC.final(value: responseData, algorithm: .SHA1, key: keyiv).subdata(in: 0..<10))
// return responseData
// }
//
// public override func input(data: Data) throws {
// inputStreamProcessor!.input(data: data)
// }
//
// public override func output(data: Data) {
// let fullBlockCount = data.count / DATA_BLOCK_SIZE
// var outputSize = fullBlockCount * (DATA_BLOCK_SIZE + 10 + 2)
// if data.count > fullBlockCount * DATA_BLOCK_SIZE {
// outputSize += data.count - fullBlockCount * DATA_BLOCK_SIZE + 10 + 2
// }
//
// let _requestData: Data = requestData()
// if !requestSend {
// outputSize += _requestData.count
// }
//
// var outputData = Data(count: outputSize)
// var outputOffset = 0
// var dataOffset = 0
//
// if !requestSend {
// requestSend = true
// outputData.replaceSubrange(0..<_requestData.count, with: _requestData)
// outputOffset += _requestData.count
// }
//
// while outputOffset != outputSize {
// let blockLength = min(data.count - dataOffset, DATA_BLOCK_SIZE)
// var len = UInt16(blockLength).bigEndian
// withUnsafeBytes(of: &len) {
// outputData.replaceSubrange(outputOffset..<outputOffset+2, with: $0)
// }
//
// var kc = Data(count: writeIV!.count + MemoryLayout.size(ofValue: count))
// kc.replaceSubrange(0..<writeIV!.count, with: writeIV!)
// var c = count.bigEndian
// let ms = MemoryLayout.size(ofValue: c)
// withUnsafeBytes(of: &c) {
// kc.replaceSubrange(writeIV!.count..<writeIV!.count+ms, with: $0)
// }
//
// data.withUnsafeBytes {
// outputData.replaceSubrange(outputOffset+2..<outputOffset+12, with: HMAC.final(value: $0.baseAddress!.advanced(by: dataOffset), length: blockLength, algorithm: .SHA1, key: kc).subdata(in: 0..<10))
// }
//
// data.withUnsafeBytes {
// outputData.replaceSubrange(outputOffset+12..<outputOffset+12+blockLength, with: $0.baseAddress!.advanced(by: dataOffset), count: blockLength)
// }
//
// count += 1
// outputOffset += 12 + blockLength
// dataOffset += blockLength
// }
//
// return outputStreamProcessor!.output(data: outputData)
// }
// }
// }
//}

View File

@@ -0,0 +1,113 @@
import Foundation
/// This class just forwards data directly.
/// - note: It is designed to work with tun2socks only.
public class DirectProxySocket: ProxySocket {
enum DirectProxyReadStatus: CustomStringConvertible {
case invalid,
forwarding,
stopped
var description: String {
switch self {
case .invalid:
return "invalid"
case .forwarding:
return "forwarding"
case .stopped:
return "stopped"
}
}
}
enum DirectProxyWriteStatus {
case invalid,
forwarding,
stopped
var description: String {
switch self {
case .invalid:
return "invalid"
case .forwarding:
return "forwarding"
case .stopped:
return "stopped"
}
}
}
private var readStatus: DirectProxyReadStatus = .invalid
private var writeStatus: DirectProxyWriteStatus = .invalid
public var readStatusDescription: String {
return readStatus.description
}
public var writeStatusDescription: String {
return writeStatus.description
}
/**
Begin reading and processing data from the socket.
- note: Since there is nothing to read and process before forwarding data, this just calls `delegate?.didReceiveRequest`.
*/
override public func openSocket() {
super.openSocket()
guard !isCancelled else {
return
}
if let address = socket.destinationIPAddress, let port = socket.destinationPort {
session = ConnectSession(host: address.presentation, port: Int(port.value))
observer?.signal(.receivedRequest(session!, on: self))
delegate?.didReceive(session: session!, from: self)
} else {
forceDisconnect()
}
}
/**
Response to the `AdapterSocket` on the other side of the `Tunnel` which has succefully connected to the remote server.
- parameter adapter: The `AdapterSocket`.
*/
override public func respondTo(adapter: AdapterSocket) {
super.respondTo(adapter: adapter)
guard !isCancelled else {
return
}
readStatus = .forwarding
writeStatus = .forwarding
observer?.signal(.readyForForward(self))
delegate?.didBecomeReadyToForwardWith(socket: self)
}
/**
The socket did read some data.
- parameter data: The data read from the socket.
- parameter from: The socket where the data is read from.
*/
override open func didRead(data: Data, from: RawTCPSocketProtocol) {
super.didRead(data: data, from: from)
delegate?.didRead(data: data, from: self)
}
/**
The socket did send some data.
- parameter data: The data which have been sent to remote (acknowledged). Note this may not be available since the data may be released to save memory.
- parameter by: The socket where the data is sent out.
*/
override open func didWrite(data: Data?, by: RawTCPSocketProtocol) {
super.didWrite(data: data, by: by)
delegate?.didWrite(data: data, by: self)
}
}

View File

@@ -0,0 +1,207 @@
import Foundation
public class HTTPProxySocket: ProxySocket {
enum HTTPProxyReadStatus: CustomStringConvertible {
case invalid,
readingFirstHeader,
pendingFirstHeader,
readingHeader,
readingContent,
stopped
var description: String {
switch self {
case .invalid:
return "invalid"
case .readingFirstHeader:
return "reading first header"
case .pendingFirstHeader:
return "waiting to send first header"
case .readingHeader:
return "reading header (forwarding)"
case .readingContent:
return "reading content (forwarding)"
case .stopped:
return "stopped"
}
}
}
enum HTTPProxyWriteStatus: CustomStringConvertible {
case invalid,
sendingConnectResponse,
forwarding,
stopped
var description: String {
switch self {
case .invalid:
return "invalid"
case .sendingConnectResponse:
return "sending response header for CONNECT"
case .forwarding:
return "waiting to begin forwarding data"
case .stopped:
return "stopped"
}
}
}
/// The remote host to connect to.
public var destinationHost: String!
/// The remote port to connect to.
public var destinationPort: Int!
private var currentHeader: HTTPHeader!
private let scanner: HTTPStreamScanner = HTTPStreamScanner()
private var readStatus: HTTPProxyReadStatus = .invalid
private var writeStatus: HTTPProxyWriteStatus = .invalid
public var isConnectCommand = false
public var readStatusDescription: String {
return readStatus.description
}
public var writeStatusDescription: String {
return writeStatus.description
}
/**
Begin reading and processing data from the socket.
*/
override public func openSocket() {
super.openSocket()
guard !isCancelled else {
return
}
readStatus = .readingFirstHeader
socket.readDataTo(data: Utils.HTTPData.DoubleCRLF)
}
override public func readData() {
guard !isCancelled else {
return
}
// Return the first header we read when the socket was opened if the proxy command is not CONNECT.
if readStatus == .pendingFirstHeader {
delegate?.didRead(data: currentHeader.toData(), from: self)
readStatus = .readingContent
return
}
switch scanner.nextAction {
case .readContent(let length):
readStatus = .readingContent
if length > 0 {
socket.readDataTo(length: length)
} else {
socket.readData()
}
case .readHeader:
readStatus = .readingHeader
socket.readDataTo(data: Utils.HTTPData.DoubleCRLF)
case .stop:
readStatus = .stopped
disconnect()
}
}
// swiftlint:disable function_body_length
// swiftlint:disable cyclomatic_complexity
/**
The socket did read some data.
- parameter data: The data read from the socket.
- parameter from: The socket where the data is read from.
*/
override public func didRead(data: Data, from: RawTCPSocketProtocol) {
super.didRead(data: data, from: from)
let result: HTTPStreamScanner.Result
do {
result = try scanner.input(data)
} catch let error {
disconnect(becauseOf: error)
return
}
switch (readStatus, result) {
case (.readingFirstHeader, .header(let header)):
currentHeader = header
currentHeader.removeProxyHeader()
currentHeader.rewriteToRelativePath()
destinationHost = currentHeader.host
destinationPort = currentHeader.port
isConnectCommand = currentHeader.isConnect
if !isConnectCommand {
readStatus = .pendingFirstHeader
} else {
readStatus = .readingContent
}
session = ConnectSession(host: destinationHost!, port: destinationPort!)
observer?.signal(.receivedRequest(session!, on: self))
delegate?.didReceive(session: session!, from: self)
case (.readingHeader, .header(let header)):
currentHeader = header
currentHeader.removeProxyHeader()
currentHeader.rewriteToRelativePath()
delegate?.didRead(data: currentHeader.toData(), from: self)
case (.readingContent, .content(let content)):
delegate?.didRead(data: content, from: self)
default:
return
}
}
/**
The socket did send some data.
- parameter data: The data which have been sent to remote (acknowledged). Note this may not be available since the data may be released to save memory.
- parameter by: The socket where the data is sent out.
*/
override public func didWrite(data: Data?, by: RawTCPSocketProtocol) {
super.didWrite(data: data, by: by)
switch writeStatus {
case .sendingConnectResponse:
writeStatus = .forwarding
observer?.signal(.readyForForward(self))
delegate?.didBecomeReadyToForwardWith(socket: self)
default:
delegate?.didWrite(data: data, by: self)
}
}
/**
Response to the `AdapterSocket` on the other side of the `Tunnel` which has succefully connected to the remote server.
- parameter adapter: The `AdapterSocket`.
*/
public override func respondTo(adapter: AdapterSocket) {
super.respondTo(adapter: adapter)
guard !isCancelled else {
return
}
if isConnectCommand {
writeStatus = .sendingConnectResponse
write(data: Utils.HTTPData.ConnectSuccessResponse)
} else {
writeStatus = .forwarding
observer?.signal(.readyForForward(self))
delegate?.didBecomeReadyToForwardWith(socket: self)
}
}
}

View File

@@ -0,0 +1,178 @@
import Foundation
/// The socket which encapsulates the logic to handle connection to proxies.
open class ProxySocket: NSObject, SocketProtocol, RawTCPSocketDelegate {
/// Received `ConnectSession`.
public var session: ConnectSession?
public var observer: Observer<ProxySocketEvent>?
private var _cancelled = false
var isCancelled: Bool {
return _cancelled
}
open override var description: String {
if let session = session {
return "<\(typeName) host:\(session.host) port: \(session.port))>"
} else {
return "<\(typeName)>"
}
}
/**
Init a `ProxySocket` with a raw TCP socket.
- parameter socket: The raw TCP socket.
*/
public init(socket: RawTCPSocketProtocol, observe: Bool = true) {
self.socket = socket
super.init()
self.socket.delegate = self
if observe {
observer = ObserverFactory.currentFactory?.getObserverForProxySocket(self)
}
}
/**
Begin reading and processing data from the socket.
*/
open func openSocket() {
guard !isCancelled else {
return
}
observer?.signal(.socketOpened(self))
}
/**
Response to the `AdapterSocket` on the other side of the `Tunnel` which has succefully connected to the remote server.
- parameter adapter: The `AdapterSocket`.
*/
open func respondTo(adapter: AdapterSocket) {
guard !isCancelled else {
return
}
observer?.signal(.askedToResponseTo(adapter, on: self))
}
/**
Read data from the socket.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
open func readData() {
guard !isCancelled else {
return
}
socket.readData()
}
/**
Send data to remote.
- parameter data: Data to send.
- warning: This should only be called after the last write is finished, i.e., `delegate?.didWriteData()` is called.
*/
open func write(data: Data) {
guard !isCancelled else {
return
}
socket.write(data: data)
}
/**
Disconnect the socket elegantly.
*/
open func disconnect(becauseOf error: Error? = nil) {
guard !isCancelled else {
return
}
_status = .disconnecting
_cancelled = true
session?.disconnected(becauseOf: error, by: .proxy)
socket.disconnect()
observer?.signal(.disconnectCalled(self))
}
/**
Disconnect the socket immediately.
*/
open func forceDisconnect(becauseOf error: Error? = nil) {
guard !isCancelled else {
return
}
_status = .disconnecting
_cancelled = true
session?.disconnected(becauseOf: error, by: .proxy)
socket.forceDisconnect()
observer?.signal(.forceDisconnectCalled(self))
}
// MARK: SocketProtocol Implementation
/// The underlying TCP socket transmitting data.
public var socket: RawTCPSocketProtocol!
/// The delegate instance.
weak public var delegate: SocketDelegate?
var _status: SocketStatus = .established
/// The current connection status of the socket.
public var status: SocketStatus {
return _status
}
// MARK: RawTCPSocketDelegate Protocol Implementation
/**
The socket did disconnect.
- parameter socket: The socket which did disconnect.
*/
open func didDisconnectWith(socket: RawTCPSocketProtocol) {
_status = .closed
observer?.signal(.disconnected(self))
delegate?.didDisconnectWith(socket: self)
}
/**
The socket did read some data.
- parameter data: The data read from the socket.
- parameter withTag: The tag given when calling the `readData` method.
- parameter from: The socket where the data is read from.
*/
open func didRead(data: Data, from: RawTCPSocketProtocol) {
observer?.signal(.readData(data, on: self))
}
/**
The socket did send some data.
- parameter data: The data which have been sent to remote (acknowledged). Note this may not be available since the data may be released to save memory.
- parameter from: The socket where the data is sent out.
*/
open func didWrite(data: Data?, by: RawTCPSocketProtocol) {
observer?.signal(.wroteData(data, on: self))
}
/**
The socket did connect to remote.
- note: This never happens for `ProxySocket`.
- parameter socket: The connected socket.
*/
open func didConnectWith(socket: RawTCPSocketProtocol) {
}
}

View File

@@ -0,0 +1,244 @@
import Foundation
public class SOCKS5ProxySocket: ProxySocket {
enum SOCKS5ProxyReadStatus: CustomStringConvertible {
case invalid,
readingVersionIdentifierAndNumberOfMethods,
readingMethods,
readingConnectHeader,
readingIPv4Address,
readingDomainLength,
readingDomain,
readingIPv6Address,
readingPort,
forwarding,
stopped
var description: String {
switch self {
case .invalid:
return "invalid"
case .readingVersionIdentifierAndNumberOfMethods:
return "reading version and methods"
case .readingMethods:
return "reading methods"
case .readingConnectHeader:
return "reading connect header"
case .readingIPv4Address:
return "IPv4 address"
case .readingDomainLength:
return "domain length"
case .readingDomain:
return "domain"
case .readingIPv6Address:
return "IPv6 address"
case .readingPort:
return "reading port"
case .forwarding:
return "forwarding"
case .stopped:
return "stopped"
}
}
}
enum SOCKS5ProxyWriteStatus: CustomStringConvertible {
case invalid,
sendingResponse,
forwarding,
stopped
var description: String {
switch self {
case .invalid:
return "invalid"
case .sendingResponse:
return "sending response"
case .forwarding:
return "forwarding"
case .stopped:
return "stopped"
}
}
}
/// The remote host to connect to.
public var destinationHost: String!
/// The remote port to connect to.
public var destinationPort: Int!
private var readStatus: SOCKS5ProxyReadStatus = .invalid
private var writeStatus: SOCKS5ProxyWriteStatus = .invalid
public var readStatusDescription: String {
return readStatus.description
}
public var writeStatusDescription: String {
return writeStatus.description
}
/**
Begin reading and processing data from the socket.
*/
override public func openSocket() {
super.openSocket()
guard !isCancelled else {
return
}
readStatus = .readingVersionIdentifierAndNumberOfMethods
socket.readDataTo(length: 2)
}
// swiftlint:disable function_body_length
// swiftlint:disable cyclomatic_complexity
/**
The socket did read some data.
- parameter data: The data read from the socket.
- parameter from: The socket where the data is read from.
*/
override public func didRead(data: Data, from: RawTCPSocketProtocol) {
super.didRead(data: data, from: from)
switch readStatus {
case .forwarding:
delegate?.didRead(data: data, from: self)
case .readingVersionIdentifierAndNumberOfMethods:
data.withUnsafeBytes { pointer in
let p = pointer.bindMemory(to: Int8.self)
guard p.baseAddress!.pointee == 5 else {
// TODO: notify observer
self.disconnect()
return
}
guard p.baseAddress!.successor().pointee > 0 else {
// TODO: notify observer
self.disconnect()
return
}
self.readStatus = .readingMethods
self.socket.readDataTo(length: Int(p.baseAddress!.successor().pointee))
}
case .readingMethods:
// TODO: check for 0x00 in read data
let response = Data([0x05, 0x00])
// we would not be able to read anything before the data is written out, so no need to handle the dataWrote event.
write(data: response)
readStatus = .readingConnectHeader
socket.readDataTo(length: 4)
case .readingConnectHeader:
data.withUnsafeBytes { pointer in
let p = pointer.bindMemory(to: Int8.self)
guard p.baseAddress!.pointee == 5 && p.baseAddress!.successor().pointee == 1 else {
// TODO: notify observer
self.disconnect()
return
}
switch p.baseAddress!.advanced(by: 3).pointee {
case 1:
readStatus = .readingIPv4Address
socket.readDataTo(length: 4)
case 3:
readStatus = .readingDomainLength
socket.readDataTo(length: 1)
case 4:
readStatus = .readingIPv6Address
socket.readDataTo(length: 16)
default:
break
}
}
case .readingIPv4Address:
var address = Data(count: Int(INET_ADDRSTRLEN))
_ = data.withUnsafeBytes { data_ptr in
address.withUnsafeMutableBytes { addr_ptr in
inet_ntop(AF_INET, data_ptr.baseAddress!, addr_ptr.bindMemory(to: Int8.self).baseAddress!, socklen_t(INET_ADDRSTRLEN))
}
}
destinationHost = String(data: address, encoding: .utf8)
readStatus = .readingPort
socket.readDataTo(length: 2)
case .readingIPv6Address:
var address = Data(count: Int(INET6_ADDRSTRLEN))
_ = data.withUnsafeBytes { data_ptr in
address.withUnsafeMutableBytes { addr_ptr in
inet_ntop(AF_INET6, data_ptr.baseAddress!, addr_ptr.bindMemory(to: Int8.self).baseAddress!, socklen_t(INET6_ADDRSTRLEN))
}
}
destinationHost = String(data: address, encoding: .utf8)
readStatus = .readingPort
socket.readDataTo(length: 2)
case .readingDomainLength:
readStatus = .readingDomain
socket.readDataTo(length: Int(data.first!))
case .readingDomain:
destinationHost = String(data: data, encoding: .utf8)
readStatus = .readingPort
socket.readDataTo(length: 2)
case .readingPort:
data.withUnsafeBytes {
destinationPort = Int($0.load(as: UInt16.self).bigEndian)
}
readStatus = .forwarding
session = ConnectSession(host: destinationHost, port: destinationPort)
observer?.signal(.receivedRequest(session!, on: self))
delegate?.didReceive(session: session!, from: self)
default:
return
}
}
/**
The socket did send some data.
- parameter data: The data which have been sent to remote (acknowledged). Note this may not be available since the data may be released to save memory.
- parameter from: The socket where the data is sent out.
*/
override public func didWrite(data: Data?, by: RawTCPSocketProtocol) {
super.didWrite(data: data, by: by)
switch writeStatus {
case .forwarding:
delegate?.didWrite(data: data, by: self)
case .sendingResponse:
writeStatus = .forwarding
observer?.signal(.readyForForward(self))
delegate?.didBecomeReadyToForwardWith(socket: self)
default:
return
}
}
/**
Response to the `AdapterSocket` on the other side of the `Tunnel` which has succefully connected to the remote server.
- parameter adapter: The `AdapterSocket`.
*/
override public func respondTo(adapter: AdapterSocket) {
super.respondTo(adapter: adapter)
guard !isCancelled else {
return
}
var responseBytes = [UInt8](repeating: 0, count: 10)
responseBytes[0...3] = [0x05, 0x00, 0x00, 0x01]
let responseData = Data(responseBytes)
writeStatus = .sendingResponse
write(data: responseData)
}
}

View File

@@ -0,0 +1,155 @@
import Foundation
/**
The current connection status of the socket.
- Invalid: The socket is just created but never connects.
- Connecting: The socket is connecting.
- Established: The connection is established.
- Disconnecting: The socket is disconnecting.
- Closed: The socket is closed.
*/
public enum SocketStatus {
/// The socket is just created but never connects.
case invalid,
/// The socket is connecting.
connecting,
/// The connection is established.
established,
/// The socket is disconnecting.
disconnecting,
/// The socket is closed.
closed
}
/// Protocol for socket with various functions.
///
/// Any concrete implementation does not need to be thread-safe.
public protocol SocketProtocol: class {
/// The underlying TCP socket transmitting data.
var socket: RawTCPSocketProtocol! { get }
/// The delegate instance.
var delegate: SocketDelegate? { get set }
/// The current connection status of the socket.
var status: SocketStatus { get }
// /// The description of the currect status.
// var statusDescription: String { get }
/// If the socket is disconnected.
var isDisconnected: Bool { get }
/// The type of the socket.
var typeName: String { get }
var readStatusDescription: String { get }
var writeStatusDescription: String { get }
/**
Read data from the socket.
- warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called.
*/
func readData()
/**
Send data to remote.
- parameter data: Data to send.
- warning: This should only be called after the last write is finished, i.e., `delegate?.didWriteData()` is called.
*/
func write(data: Data)
/**
Disconnect the socket elegantly.
*/
func disconnect(becauseOf error: Error?)
/**
Disconnect the socket immediately.
*/
func forceDisconnect(becauseOf error: Error?)
}
extension SocketProtocol {
/// If the socket is disconnected.
public var isDisconnected: Bool {
return status == .closed || status == .invalid
}
public var typeName: String {
return String(describing: type(of: self))
}
public var readStatusDescription: String {
return "\(status)"
}
public var writeStatusDescription: String {
return "\(status)"
}
}
/// The delegate protocol to handle the events from a socket.
public protocol SocketDelegate : class {
/**
The socket did connect to remote.
- parameter adapterSocket: The connected socket.
*/
func didConnectWith(adapterSocket: AdapterSocket)
/**
The socket did disconnect.
This should only be called once in the entire lifetime of a socket. After this is called, the delegate will not receive any other events from that socket and the socket should be released.
- parameter socket: The socket which did disconnect.
*/
func didDisconnectWith(socket: SocketProtocol)
/**
The socket did read some data.
- parameter data: The data read from the socket.
- parameter from: The socket where the data is read from.
*/
func didRead(data: Data, from: SocketProtocol)
/**
The socket did send some data.
- parameter data: The data which have been sent to remote (acknowledged). Note this may not be available since the data may be released to save memory.
- parameter by: The socket where the data is sent out.
*/
func didWrite(data: Data?, by: SocketProtocol)
/**
The socket is ready to forward data back and forth.
- parameter socket: The socket which becomes ready to forward data.
*/
func didBecomeReadyToForwardWith(socket: SocketProtocol)
/**
Did receive a `ConnectSession` from local now it is time to connect to remote.
- parameter session: The received `ConnectSession`.
- parameter from: The socket where the `ConnectSession` is received.
*/
func didReceive(session: ConnectSession, from: ProxySocket)
/**
The adapter socket decided to replace itself with a new `AdapterSocket` to connect to remote.
- parameter newAdapter: The new `AdapterSocket` to replace the old one.
*/
func updateAdapterWith(newAdapter: AdapterSocket)
}

View File

@@ -0,0 +1,29 @@
import Foundation
class QueueFactory {
private static let queueKey = DispatchSpecificKey<String>()
static let queue: DispatchQueue = {
let q = DispatchQueue(label: "NEKit.ProcessingQueue")
q.setSpecific(key: QueueFactory.queueKey, value: "NEKit.ProcessingQueue")
return q
}()
static func getQueue() -> DispatchQueue {
return QueueFactory.queue
}
static func onQueue() -> Bool {
return DispatchQueue.getSpecific(key: QueueFactory.queueKey) == "NEKit.ProcessingQueue"
}
static func executeOnQueueSynchronizedly<T>(block: () throws -> T ) rethrows -> T {
if onQueue() {
return try block()
} else {
return try getQueue().sync {
return try block()
}
}
}
}

View File

@@ -0,0 +1,274 @@
import Foundation
protocol TunnelDelegate : class {
func tunnelDidClose(_ tunnel: Tunnel)
}
/// The tunnel forwards data between local and remote.
public class Tunnel: NSObject, SocketDelegate {
/// The status of `Tunnel`.
public enum TunnelStatus: CustomStringConvertible {
case invalid, readingRequest, waitingToBeReady, forwarding, closing, closed
public var description: String {
switch self {
case .invalid:
return "invalid"
case .readingRequest:
return "reading request"
case .waitingToBeReady:
return "waiting to be ready"
case .forwarding:
return "forwarding"
case .closing:
return "closing"
case .closed:
return "closed"
}
}
}
/// The proxy socket.
var proxySocket: ProxySocket
/// The adapter socket connecting to remote.
var adapterSocket: AdapterSocket?
/// The delegate instance.
weak var delegate: TunnelDelegate?
var observer: Observer<TunnelEvent>?
/// Indicating how many socket is ready to forward data.
private var readySignal = 0
/// If the tunnel is closed, i.e., proxy socket and adapter socket are both disconnected.
var isClosed: Bool {
return proxySocket.isDisconnected && (adapterSocket?.isDisconnected ?? true)
}
fileprivate var _cancelled: Bool = false
fileprivate var _stopForwarding = false
public var isCancelled: Bool {
return _cancelled
}
fileprivate var _status: TunnelStatus = .invalid
public var status: TunnelStatus {
return _status
}
public var statusDescription: String {
return status.description
}
override public var description: String {
if let adapterSocket = adapterSocket {
return "<Tunnel proxySocket:\(proxySocket) adapterSocket:\(adapterSocket)>"
} else {
return "<Tunnel proxySocket:\(proxySocket)>"
}
}
init(proxySocket: ProxySocket) {
self.proxySocket = proxySocket
super.init()
self.proxySocket.delegate = self
self.observer = ObserverFactory.currentFactory?.getObserverForTunnel(self)
}
/**
Start running the tunnel.
*/
func openTunnel() {
guard !self.isCancelled else {
return
}
self.proxySocket.openSocket()
self._status = .readingRequest
self.observer?.signal(.opened(self))
}
/**
Close the tunnel elegantly.
*/
func close() {
observer?.signal(.closeCalled(self))
guard !self.isCancelled else {
return
}
self._cancelled = true
self._status = .closing
if !self.proxySocket.isDisconnected {
self.proxySocket.disconnect()
}
if let adapterSocket = self.adapterSocket {
if !adapterSocket.isDisconnected {
adapterSocket.disconnect()
}
}
}
/// Close the tunnel immediately.
///
/// - note: This method is thread-safe.
func forceClose() {
observer?.signal(.forceCloseCalled(self))
guard !self.isCancelled else {
return
}
self._cancelled = true
self._status = .closing
self._stopForwarding = true
if !self.proxySocket.isDisconnected {
self.proxySocket.forceDisconnect()
}
if let adapterSocket = self.adapterSocket {
if !adapterSocket.isDisconnected {
adapterSocket.forceDisconnect()
}
}
}
public func didReceive(session: ConnectSession, from: ProxySocket) {
guard !isCancelled else {
return
}
_status = .waitingToBeReady
observer?.signal(.receivedRequest(session, from: from, on: self))
if !session.isIP() {
_ = Resolver.resolve(hostname: session.host, timeout: Opt.DNSTimeout) { [weak self] resolver, err in
QueueFactory.getQueue().async {
if err != nil {
session.ipAddress = ""
} else {
session.ipAddress = (resolver?.ipv4Result.first)!
}
self?.openAdapter(for: session)
}
}
} else {
session.ipAddress = session.host
openAdapter(for: session)
}
}
fileprivate func openAdapter(for session: ConnectSession) {
guard !isCancelled else {
return
}
let manager = RuleManager.currentManager
let factory = manager.match(session)!
adapterSocket = factory.getAdapterFor(session: session)
adapterSocket!.delegate = self
adapterSocket!.openSocketWith(session: session)
}
public func didBecomeReadyToForwardWith(socket: SocketProtocol) {
guard !isCancelled else {
return
}
readySignal += 1
observer?.signal(.receivedReadySignal(socket, currentReady: readySignal, on: self))
defer {
if let socket = socket as? AdapterSocket {
proxySocket.respondTo(adapter: socket)
}
}
if readySignal == 2 {
_status = .forwarding
proxySocket.readData()
adapterSocket?.readData()
}
}
public func didDisconnectWith(socket: SocketProtocol) {
if !isCancelled {
_stopForwarding = true
close()
}
checkStatus()
}
public func didRead(data: Data, from socket: SocketProtocol) {
if let socket = socket as? ProxySocket {
observer?.signal(.proxySocketReadData(data, from: socket, on: self))
guard !isCancelled else {
return
}
adapterSocket!.write(data: data)
} else if let socket = socket as? AdapterSocket {
observer?.signal(.adapterSocketReadData(data, from: socket, on: self))
guard !isCancelled else {
return
}
proxySocket.write(data: data)
}
}
public func didWrite(data: Data?, by socket: SocketProtocol) {
if let socket = socket as? ProxySocket {
observer?.signal(.proxySocketWroteData(data, by: socket, on: self))
guard !isCancelled else {
return
}
QueueFactory.getQueue().asyncAfter(deadline: DispatchTime.now() + DispatchTimeInterval.microseconds(Opt.forwardReadInterval)) { [weak self] in
self?.adapterSocket?.readData()
}
} else if let socket = socket as? AdapterSocket {
observer?.signal(.adapterSocketWroteData(data, by: socket, on: self))
guard !isCancelled else {
return
}
proxySocket.readData()
}
}
public func didConnectWith(adapterSocket: AdapterSocket) {
guard !isCancelled else {
return
}
observer?.signal(.connectedToRemote(adapterSocket, on: self))
}
public func updateAdapterWith(newAdapter: AdapterSocket) {
guard !isCancelled else {
return
}
observer?.signal(.updatingAdapterSocket(from: adapterSocket!, to: newAdapter, on: self))
adapterSocket = newAdapter
adapterSocket?.delegate = self
}
fileprivate func checkStatus() {
if isClosed {
_status = .closed
observer?.signal(.closed(self))
delegate?.tunnelDidClose(self)
delegate = nil
}
}
}

View File

@@ -0,0 +1,114 @@
import Foundation
public struct Utils {
public struct HTTPData {
public static let DoubleCRLF = "\r\n\r\n".data(using: String.Encoding.utf8)!
public static let CRLF = "\r\n".data(using: String.Encoding.utf8)!
public static let ConnectSuccessResponse = "HTTP/1.1 200 Connection Established\r\n\r\n".data(using: String.Encoding.utf8)!
}
public struct DNS {
// swiftlint:disable:next nesting
public enum QueryType {
// swiftlint:disable:next type_name
case a, aaaa, unspec
}
public static func resolve(_ name: String, type: QueryType = .unspec) -> String {
let remoteHostEnt = gethostbyname2((name as NSString).utf8String, AF_INET)
if remoteHostEnt == nil {
return ""
}
let remoteAddr = UnsafeMutableRawPointer(remoteHostEnt?.pointee.h_addr_list[0])
var output = [Int8](repeating: 0, count: Int(INET6_ADDRSTRLEN))
inet_ntop(AF_INET, remoteAddr, &output, socklen_t(INET6_ADDRSTRLEN))
return NSString(utf8String: output)! as String
}
}
// swiftlint:disable:next type_name
public struct IP {
public static func isIPv4(_ ipAddress: String) -> Bool {
if IPv4ToInt(ipAddress) != nil {
return true
} else {
return false
}
}
public static func isIPv6(_ ipAddress: String) -> Bool {
let utf8Str = (ipAddress as NSString).utf8String
var dst = [UInt8](repeating: 0, count: 16)
return inet_pton(AF_INET6, utf8Str, &dst) == 1
}
public static func isIP(_ ipAddress: String) -> Bool {
return isIPv4(ipAddress) || isIPv6(ipAddress)
}
public static func IPv4ToInt(_ ipAddress: String) -> UInt32? {
let utf8Str = (ipAddress as NSString).utf8String
var dst = in_addr(s_addr: 0)
if inet_pton(AF_INET, utf8Str, &(dst.s_addr)) == 1 {
return UInt32(dst.s_addr)
} else {
return nil
}
}
public static func IPv4ToBytes(_ ipAddress: String) -> [UInt8]? {
if let ipv4int = IPv4ToInt(ipAddress) {
return Utils.toByteArray(ipv4int).reversed()
} else {
return nil
}
}
public static func IPv6ToBytes(_ ipAddress: String) -> [UInt8]? {
let utf8Str = (ipAddress as NSString).utf8String
var dst = [UInt8](repeating: 0, count: 16)
if inet_pton(AF_INET6, utf8Str, &dst) == 1 {
return Utils.toByteArray(dst).reversed()
} else {
return nil
}
}
}
// struct GeoIPLookup {
//
// static func Lookup(_ ipAddress: String) -> String? {
// if Utils.IP.isIP(ipAddress) {
// guard let result = GeoIP.LookUp(ipAddress) else {
// return "--"
// }
// return result.isoCode
// } else {
// return nil
// }
// }
// }
static func toByteArray<T>(_ value: T) -> [UInt8] {
var value = value
return withUnsafeBytes(of: &value) {
Array($0)
}
}
struct Random {
static func fill(data: inout Data, from: Int = 0, to: Int = -1) {
let c = data.count
data.withUnsafeMutableBytes {
arc4random_buf($0.baseAddress!.advanced(by: from), to == -1 ? c - from : to - from)
}
}
static func fill(data: inout Data, from: Int = 0, length: Int) {
fill(data: &data, from: from, to: from + length)
}
}
}

View File

@@ -0,0 +1,93 @@
//
// BinaryDataScanner.swift
// Murphy
//
// Created by Dave Peck on 7/20/14.
// Copyright (c) 2014 Dave Peck. All rights reserved.
//
import Foundation
/*
Toying with tools to help read binary formats.
I've seen lots of approaches in swift that create
an intermediate object per-read (usually another NSData)
but even if these are lightweight under the hood,
it seems like overkill. Plus this taught me about <()> aka <Void>
And it would be nice to have an extension to
NSFileHandle too that does much the same.
*/
public protocol BinaryReadable {
var littleEndian: Self { get }
var bigEndian: Self { get }
}
extension UInt8: BinaryReadable {
public var littleEndian: UInt8 { return self }
public var bigEndian: UInt8 { return self }
}
extension UInt16: BinaryReadable {}
extension UInt32: BinaryReadable {}
extension UInt64: BinaryReadable {}
open class BinaryDataScanner {
let data: Data
let littleEndian: Bool
// let encoding: NSStringEncoding
var remaining: Int {
return data.count - position
}
var position: Int = 0
public init(data: Data, littleEndian: Bool) {
self.data = data
self.littleEndian = littleEndian
}
open func read<T: BinaryReadable>() -> T? {
if remaining < MemoryLayout<T>.size {
return nil
}
let v = data.withUnsafeBytes {
$0.baseAddress!.advanced(by: position).assumingMemoryBound(to: T.self).pointee
}
position += MemoryLayout<T>.size
return littleEndian ? v.littleEndian : v.bigEndian
}
// swiftlint:disable variable_name
open func skip(to n: Int) {
position = n
}
open func advance(by n: Int) {
position += n
}
/* convenience read funcs */
open func readByte() -> UInt8? {
return read()
}
open func read16() -> UInt16? {
return read()
}
open func read32() -> UInt32? {
return read()
}
open func read64() -> UInt64? {
return read()
}
}

View File

@@ -0,0 +1,44 @@
import Foundation
open class Checksum {
public static func computeChecksum(_ data: Data, from start: Int = 0, to end: Int? = nil, withPseudoHeaderChecksum initChecksum: UInt32 = 0) -> UInt16 {
return toChecksum(computeChecksumUnfold(data, from: start, to: end, withPseudoHeaderChecksum: initChecksum))
}
public static func validateChecksum(_ payload: Data, from start: Int = 0, to end: Int? = nil) -> Bool {
let cs = computeChecksumUnfold(payload, from: start, to: end)
return toChecksum(cs) == 0
}
public static func computeChecksumUnfold(_ data: Data, from start: Int = 0, to end: Int? = nil, withPseudoHeaderChecksum initChecksum: UInt32 = 0) -> UInt32 {
let scanner = BinaryDataScanner(data: data, littleEndian: true)
scanner.skip(to: start)
var result: UInt32 = initChecksum
var end = end
if end == nil {
end = data.count
}
while scanner.position + 2 <= end! {
let value = scanner.read16()!
result += UInt32(value)
}
if scanner.position != end {
// data is of odd size
// Intel and ARM are both litten endian
// so just add it
let value = scanner.readByte()!
result += UInt32(value)
}
return result
}
public static func toChecksum(_ checksum: UInt32) -> UInt16 {
var result = checksum
while (result) >> 16 != 0 {
result = result >> 16 + result & 0xFFFF
}
return ~UInt16(result)
}
}

View File

@@ -0,0 +1,44 @@
import Foundation
/**
* The helper wrapping up an HTTP basic authentication credential.
*/
public struct HTTPAuthentication {
/// The username of the credential.
public let username: String
/// The password of the credential.
public let password: String
/**
Initailize the credential with username and password.
- parameter username: The username of the credential.
- parameter password: The password of the credential.
- returns: The credential.
*/
public init(username: String, password: String) {
self.username = username
self.password = password
}
/**
Return the base64 encoded string of the credential.
- returns: The credential encoded with `"\(username):\(password)"`
*/
public func encoding() -> String? {
let auth = "\(username):\(password)"
return auth.data(using: String.Encoding.utf8)?.base64EncodedString(options: NSData.Base64EncodingOptions.endLineWithLineFeed)
}
/**
Return the full header field content for `Authorization` of an HTTP basic authentication.
- returns: The encoded authentication string.
*/
public func authString() -> String {
return "Basic \(encoding()!)"
}
}

View File

@@ -0,0 +1,81 @@
import Foundation
class HTTPStreamScanner {
enum ReadAction {
case readHeader, readContent(Int), stop
}
enum Result {
case header(HTTPHeader), content(Data)
}
enum HTTPStreamScannerError: Error {
case contentIsTooLong, scannerIsStopped, unsupportedStreamType
}
var nextAction: ReadAction = .readHeader
var remainContentLength: Int = 0
var currentHeader: HTTPHeader!
var isConnect: Bool = false
func input(_ data: Data) throws -> Result {
switch nextAction {
case .readHeader:
let header: HTTPHeader
do {
header = try HTTPHeader(headerData: data)
// To temporarily solve a bug in firefox for mac
if currentHeader != nil && header.host != currentHeader.host {
throw HTTPStreamScannerError.unsupportedStreamType
}
} catch let error {
nextAction = .stop
throw error
}
if currentHeader == nil {
if header.isConnect {
isConnect = true
remainContentLength = -1
} else {
isConnect = false
remainContentLength = header.contentLength
}
} else {
remainContentLength = header.contentLength
}
currentHeader = header
setNextAction()
return .header(header)
case .readContent:
remainContentLength -= data.count
if !isConnect && remainContentLength < 0 {
nextAction = .stop
throw HTTPStreamScannerError.contentIsTooLong
}
setNextAction()
return .content(data)
case .stop:
throw HTTPStreamScannerError.scannerIsStopped
}
}
fileprivate func setNextAction() {
switch remainContentLength {
case 0:
nextAction = .readHeader
case _ where remainContentLength < 0:
nextAction = .readContent(-1)
default:
nextAction = .readContent(min(remainContentLength, Opt.MAXHTTPContentBlockLength))
}
}
}

View File

@@ -0,0 +1,57 @@
import Foundation
public class HTTPURL {
public let scheme: String?
public let host: String?
public let port: Int?
// public let path: String
public let relativePath: String
// swiftlint:disable:next force_try
static let urlreg = try! NSRegularExpression(pattern: "^(?:(?:(https?):\\/\\/)?([\\w\\.-]+)(?::(\\d+))?)?(?:\\/(.*))?$", options: NSRegularExpression.Options.caseInsensitive)
init?(string url: String) {
let nsurl = url as NSString
guard let result = HTTPURL.urlreg.firstMatch(in: url, range: NSRange(location: 0, length: nsurl.length)) else {
return nil
}
guard result.numberOfRanges == 5 else {
return nil
}
guard result.range(at: 0).location != NSNotFound else {
return nil
}
var range = result.range(at: 1)
if range.location != NSNotFound {
scheme = nsurl.substring(with: range)
} else {
scheme = nil
}
range = result.range(at: 2)
if range.location != NSNotFound {
host = nsurl.substring(with: range)
} else {
host = nil
}
range = result.range(at: 3)
if range.location != NSNotFound {
port = Int(nsurl.substring(with: range))
} else {
port = nil
}
range = result.range(at: 4)
if range.location != NSNotFound {
relativePath = nsurl.substring(with: range)
} else {
relativePath = ""
}
}
}

View File

@@ -0,0 +1,208 @@
import Foundation
public class IPAddress: CustomStringConvertible, Comparable {
public enum Family {
case IPv4, IPv6
}
public enum Address: Equatable {
case IPv4(in_addr), IPv6(in6_addr)
public var asUInt128: UInt128 {
switch self {
case .IPv4(let addr):
return UInt128(addr.s_addr.byteSwapped)
case .IPv6(var addr):
var upperBits: UInt64 = 0, lowerBits: UInt64 = 0
withUnsafeBytes(of: &addr) {
upperBits = $0.load(as: UInt64.self).byteSwapped
lowerBits = $0.load(fromByteOffset: MemoryLayout<UInt64>.size, as: UInt64.self).byteSwapped
}
return UInt128(upperBits: upperBits, lowerBits: lowerBits)
}
}
}
public let family: Family
public let address: Address
public lazy var presentation: String = { [unowned self] in
switch self.address {
case .IPv4(var addr):
var buffer = [Int8](repeating: 0, count: Int(INET_ADDRSTRLEN))
var p: UnsafePointer<Int8>! = nil
withUnsafePointer(to: &addr) {
p = inet_ntop(AF_INET, $0, &buffer, UInt32(INET_ADDRSTRLEN))
}
return String(cString: p)
case .IPv6(var addr):
var buffer = [Int8](repeating: 0, count: Int(INET6_ADDRSTRLEN))
var p: UnsafePointer<Int8>! = nil
withUnsafePointer(to: &addr) {
p = inet_ntop(AF_INET6, $0, &buffer, UInt32(INET6_ADDRSTRLEN))
}
return String(cString: p)
}
}()
public init(fromInAddr addr: in_addr) {
family = .IPv4
address = .IPv4(addr)
}
public init(fromIn6Addr addr6: in6_addr) {
family = .IPv6
address = .IPv6(addr6)
}
public convenience init?(fromString string: String) {
var addr = in_addr()
if (string.withCString {
return inet_pton(AF_INET, $0, &addr)
}) == 1 {
self.init(fromInAddr: addr)
presentation = string
} else {
var addr6 = in6_addr()
if (string.withCString {
return inet_pton(AF_INET6, $0, &addr6)
}) == 1 {
self.init(fromIn6Addr: addr6)
presentation = string
} else {
return nil
}
}
}
public convenience init(ipv4InNetworkOrder: UInt32) {
let addr = in_addr(s_addr: ipv4InNetworkOrder)
self.init(fromInAddr: addr)
}
public convenience init(ipv6InNetworkOrder: UInt128) {
var ip = ipv6InNetworkOrder
var addr = in6_addr()
withUnsafeBytes(of: &ip) { ipptr in
withUnsafeMutableBytes(of: &addr) { addrptr in
addrptr.storeBytes(of: ipptr.load(fromByteOffset: MemoryLayout<UInt64>.size, as: UInt64.self), toByteOffset: 0, as: UInt64.self)
addrptr.storeBytes(of: ipptr.load(as: UInt64.self), toByteOffset: MemoryLayout<UInt64>.size, as: UInt64.self)
}
}
self.init(fromIn6Addr: addr)
}
public convenience init(fromBytesInNetworkOrder ptr: UnsafeRawPointer, family: Family = .IPv4) {
switch family {
case .IPv4:
let addr = ptr.assumingMemoryBound(to: in_addr.self).pointee
self.init(fromInAddr: addr)
case .IPv6:
let addr6 = ptr.assumingMemoryBound(to: in6_addr.self).pointee
self.init(fromIn6Addr: addr6)
}
}
public var description: String {
return presentation
}
public var dataInNetworkOrder: Data {
var outputData: Data? = nil
withBytesInNetworkOrder {
outputData = Data($0)
}
return outputData!
}
public var UInt32InNetworkOrder: UInt32? {
switch self.address {
case .IPv4(let addr):
return addr.s_addr
default:
return nil
}
}
public var UInt128InNetworkOrder: UInt128? {
return self.address.asUInt128.byteSwapped
}
public func withBytesInNetworkOrder<U>(_ body: (UnsafeRawBufferPointer) throws -> U) rethrows -> U {
switch address {
case .IPv4(var addr):
return try withUnsafeBytes(of: &addr, body)
case .IPv6(var addr):
return try withUnsafeBytes(of: &addr, body)
}
}
public func advanced(by interval: IPInterval) -> IPAddress? {
switch (interval, address) {
case (.IPv4(let range), .IPv4(let addr)):
return IPAddress(ipv4InNetworkOrder: (addr.s_addr.byteSwapped &+ range).byteSwapped)
case (.IPv6(let range), .IPv6):
return IPAddress(ipv6InNetworkOrder: (address.asUInt128 &+ range).byteSwapped)
default:
return nil
}
}
public func advanced(by interval: UInt) -> IPAddress? {
switch self.address {
case .IPv4(let addr):
return IPAddress(ipv4InNetworkOrder: (addr.s_addr.byteSwapped &+ UInt32(interval)).byteSwapped)
case .IPv6:
return IPAddress(ipv6InNetworkOrder: (address.asUInt128 &+ UInt128(interval)).byteSwapped)
}
}
}
public func == (lhs: IPAddress, rhs: IPAddress) -> Bool {
return lhs.address == rhs.address
}
// Comparing IP addresses of different families are undefined.
// But currently, IPv4 is considered smaller than IPv6 address. Do NOT depend on this behavior.
public func < (lhs: IPAddress, rhs: IPAddress) -> Bool {
switch (lhs.address, rhs.address) {
case (.IPv4(let addrl), .IPv4(let addrr)):
return addrl.s_addr.byteSwapped < addrr.s_addr.byteSwapped
case (.IPv6(var addrl), .IPv6(var addrr)):
let ms = MemoryLayout.size(ofValue: addrl)
return (withUnsafeBytes(of: &addrl) { ptrl in
withUnsafeBytes(of: &addrr) { ptrr in
return memcmp(ptrl.baseAddress!, ptrr.baseAddress!, ms)
}
}) < 0
case (.IPv4, .IPv6):
return true
case (.IPv6, .IPv4):
return false
}
}
public func == (lhs: IPAddress.Address, rhs: IPAddress.Address) -> Bool {
switch (lhs, rhs) {
case (.IPv4(let addrl), .IPv4(let addrr)):
return addrl.s_addr == addrr.s_addr
case (.IPv6(let addrl), .IPv6(let addrr)):
return addrl.__u6_addr.__u6_addr32 == addrr.__u6_addr.__u6_addr32
default:
return false
}
}
extension IPAddress: Hashable {
public func hash(into hasher: inout Hasher) {
switch address {
case .IPv4(let addr):
return hasher.combine(addr.s_addr.hashValue)
case .IPv6(var addr):
return withUnsafeBytes(of: &addr) {
return hasher.combine(bytes: $0)
}
}
}
}

View File

@@ -0,0 +1,5 @@
import Foundation
public enum IPInterval {
case IPv4(UInt32), IPv6(UInt128)
}

View File

@@ -0,0 +1,50 @@
import Foundation
public enum IPMask {
case IPv4(UInt32), IPv6(UInt128)
func mask(baseIP: IPAddress) -> (IPAddress, IPAddress)? {
switch (self, baseIP.address) {
case (.IPv4(var m), .IPv4(let addr)):
guard m <= 32 else {
return nil
}
if m == 32 {
return (baseIP, baseIP)
}
if m == 0 {
return (IPAddress(ipv4InNetworkOrder: 0), IPAddress(ipv4InNetworkOrder: UInt32.max))
}
m = 32 - m
let base = (addr.s_addr.byteSwapped >> m) << m
let end = base | ~((UInt32.max >> m) << m)
let b = IPAddress(ipv4InNetworkOrder: base.byteSwapped)
let e = IPAddress(ipv4InNetworkOrder: end.byteSwapped)
return (b, e)
case (.IPv6(var m), .IPv6):
guard m <= 128 else {
return nil
}
if m == 128 {
return (baseIP, baseIP)
}
if m == 0 {
return (IPAddress(ipv6InNetworkOrder: 0), IPAddress(ipv6InNetworkOrder: UInt128.max))
}
m = 128 - m
let base = (baseIP.address.asUInt128.byteSwapped >> m) << m
let end = base | ~((UInt128.max >> m) << m)
let b = IPAddress(ipv6InNetworkOrder: base.byteSwapped)
let e = IPAddress(ipv6InNetworkOrder: end.byteSwapped)
return (b, e)
default:
return nil
}
}
}

View File

@@ -0,0 +1,47 @@
import Foundation
/**
The pool is build to hold fake ips.
- note: It is NOT thread-safe.
*/
public final class IPPool {
let family: IPAddress.Family
let range: IPRange
var currentEnd: IPAddress
var pool: [IPAddress] = []
public init(range: IPRange) {
family = range.family
self.range = range
currentEnd = range.startIP
}
func fetchIP() -> IPAddress? {
if pool.count == 0 {
if range.contains(ip: currentEnd) {
defer {
currentEnd = currentEnd.advanced(by: 1)!
}
return currentEnd
} else {
return nil
}
}
return pool.removeLast()
}
func release(ip: IPAddress) {
guard ip.family == family else {
return
}
pool.append(ip)
}
func contains(ip: IPAddress) -> Bool {
return range.contains(ip: ip)
}
}

View File

@@ -0,0 +1,122 @@
import Foundation
public enum IPRangeError: Error {
case invalidCIDRFormat, invalidRangeFormat, invalidRange, invalidFormat, addressIncompatible, noRange, intervalInvalid, invalidMask
}
public class IPRange {
public let startIP: IPAddress
// including, so we can include 255.255.255.255 in range.
public let endIP: IPAddress
public let family: IPAddress.Family
public init(startIP: IPAddress, endIP: IPAddress) throws {
guard startIP.family == endIP.family else {
throw IPRangeError.addressIncompatible
}
guard startIP <= endIP else {
throw IPRangeError.invalidRange
}
self.startIP = startIP
self.endIP = endIP
family = startIP.family
}
public convenience init(startIP: IPAddress, interval: IPInterval) throws {
guard let endIP = startIP.advanced(by: interval) else {
throw IPRangeError.intervalInvalid
}
try self.init(startIP: startIP, endIP: endIP)
}
public convenience init(startIP: IPAddress, mask: IPMask) throws {
guard let (startIP, endIP) = mask.mask(baseIP: startIP) else {
throw IPRangeError.invalidMask
}
try self.init(startIP: startIP, endIP: endIP)
}
public func contains(ip: IPAddress) -> Bool {
guard ip.family == family else {
return false
}
return ip >= startIP && ip <= endIP
}
}
extension IPRange {
public convenience init(withCIDRString rep: String) throws {
let info = rep.components(separatedBy: "/")
guard info.count == 2 else {
throw IPRangeError.invalidCIDRFormat
}
guard let ip = IPAddress(fromString: info[0]) else {
throw IPRangeError.invalidCIDRFormat
}
var mask: IPMask
switch ip.family {
case .IPv4:
guard let m = UInt32(info[1]) else {
throw IPRangeError.invalidCIDRFormat
}
mask = IPMask.IPv4(m)
case .IPv6:
guard let m6 = try? UInt128(info[1]) else {
throw IPRangeError.invalidCIDRFormat
}
mask = IPMask.IPv6(m6)
}
try self.init(startIP: ip, mask: mask)
}
public convenience init(withRangeString rep: String) throws {
let info = rep.components(separatedBy: "+")
guard info.count == 2 else {
throw IPRangeError.invalidRangeFormat
}
guard let startIP = IPAddress(fromString: info[0]) else {
throw IPRangeError.invalidRangeFormat
}
var interval: IPInterval
switch startIP.family {
case .IPv4:
guard let m = UInt32(info[1]) else {
throw IPRangeError.invalidRangeFormat
}
interval = IPInterval.IPv4(m)
case .IPv6:
guard let m6 = try? UInt128(info[1]) else {
throw IPRangeError.invalidRangeFormat
}
interval = IPInterval.IPv6(m6)
}
try self.init(startIP: startIP, interval: interval)
}
public convenience init(withString rep: String) throws {
if rep.contains("/") {
try self.init(withCIDRString: rep)
} else if rep.contains("+") {
try self.init(withRangeString: rep)
} else {
guard let ip = IPAddress(fromString: rep) else {
throw IPRangeError.invalidFormat
}
try self.init(startIP: ip, endIP: ip)
}
}
}

View File

@@ -0,0 +1,77 @@
import Foundation
/// Represents the port number of IP protocol.
public struct Port: CustomStringConvertible, ExpressibleByIntegerLiteral {
public typealias IntegerLiteralType = UInt16
fileprivate var inport: UInt16
/**
Initialize a new instance with the port number in network byte order.
- parameter portInNetworkOrder: The port number in network byte order.
- returns: The initailized port.
*/
public init(portInNetworkOrder: UInt16) {
self.inport = portInNetworkOrder
}
/**
Initialize a new instance with the port number.
- parameter port: The port number.
- returns: The initailized port.
*/
public init(port: UInt16) {
self.init(portInNetworkOrder: NSSwapHostShortToBig(port))
}
public init(integerLiteral value: Port.IntegerLiteralType) {
self.init(port: value)
}
/**
Initialize a new instance with data in network byte order.
- parameter bytesInNetworkOrder: The port data in network byte order.
- returns: The initailized port.
*/
public init(bytesInNetworkOrder: UnsafeRawPointer) {
self.init(portInNetworkOrder: bytesInNetworkOrder.load(as: UInt16.self))
}
public var description: String {
return "<Port \(value)>"
}
/// The port number.
public var value: UInt16 {
return NSSwapBigShortToHost(inport)
}
public var valueInNetworkOrder: UInt16 {
return inport
}
/**
Run a block with the bytes of port in **network order**.
- parameter block: The block to run.
- returns: The value the block returns.
*/
public mutating func withUnsafeBufferPointer<T>(_ block: (UnsafeRawBufferPointer) -> T) -> T {
return withUnsafeBytes(of: &inport) {
return block($0)
}
}
}
public func == (left: Port, right: Port) -> Bool {
return left.valueInNetworkOrder == right.valueInNetworkOrder
}
extension Port: Hashable {}

View File

@@ -0,0 +1,41 @@
import Foundation
open class StreamScanner {
var receivedData: NSMutableData = NSMutableData()
let pattern: Data
let maximumLength: Int
var finished = false
var currentLength: Int {
return receivedData.length
}
public init(pattern: Data, maximumLength: Int) {
self.pattern = pattern
self.maximumLength = maximumLength
}
// I know this is not the most effcient algorithm if there is a large number of NSDatas, but since we only need to find the CRLF in http header (as of now), and it should be ready in the first readData call, there is no need to implement a complicate algorithm which is very likely to be slower in such case.
open func addAndScan(_ data: Data) -> (Data?, Data)? {
guard finished == false else {
return nil
}
receivedData.append(data)
let startind = max(0, receivedData.length - pattern.count - data.count)
let range = receivedData.range(of: pattern, options: .backwards, in: NSRange(location: startind, length: receivedData.length - startind))
if range.location == NSNotFound {
if receivedData.length > maximumLength {
finished = true
return (nil, receivedData as Data)
} else {
return nil
}
} else {
finished = true
let foundEndIndex = range.location + range.length
return (receivedData.subdata(with: NSRange(location: 0, length: foundEndIndex)), receivedData.subdata(with: NSRange(location: foundEndIndex, length: receivedData.length - foundEndIndex)))
}
}
}

View File

@@ -0,0 +1,801 @@
//
// UInt128.swift
//
// An implementation of a 128-bit unsigned integer data type not
// relying on any outside libraries apart from Swift's standard
// library. It also seeks to implement the entirety of the
// UnsignedInteger protocol as well as standard functions supported
// by Swift's native unsigned integer types.
//
// Copyright 2017 Joel Gerber
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// MARK: Error Type
/// An `ErrorType` for `UInt128` data types. It includes cases
/// for errors that can occur during string
/// conversion.
public enum UInt128Errors : Error {
/// Input cannot be converted to a UInt128 value.
case invalidString
}
// MARK: - Data Type
/// A 128-bit unsigned integer value type.
/// Storage is based upon a tuple of 2, 64-bit, unsigned integers.
public struct UInt128 {
// MARK: Instance Properties
/// Internal value is presented as a tuple of 2 64-bit
/// unsigned integers.
internal var value: (upperBits: UInt64, lowerBits: UInt64)
/// Counts up the significant bits in stored data.
public var significantBits: UInt128 {
var significantBits: UInt128 = 0
var bitsToWalk: UInt64 = 0 // The bits to crawl in loop.
// When upperBits > 0, lowerBits are all significant.
if self.value.upperBits > 0 {
bitsToWalk = self.value.upperBits
significantBits = 64
} else if self.value.lowerBits > 0 {
bitsToWalk = self.value.lowerBits
}
// Walk significant bits by shifting right until all bits are equal to 0.
while bitsToWalk > 0 {
bitsToWalk >>= 1
significantBits += 1
}
return significantBits
}
/// Undocumented private variable required for passing this type
/// to a BinaryFloatingPoint type. See FloatingPoint.swift.gyb in
/// the Swift stdlib/public/core directory.
internal var signBitIndex: Int {
return 127 - leadingZeroBitCount
}
// MARK: Initializers
/// Designated initializer for the UInt128 type.
public init(upperBits: UInt64, lowerBits: UInt64) {
value.upperBits = upperBits
value.lowerBits = lowerBits
}
public init() {
self.init(upperBits: 0, lowerBits: 0)
}
public init(_ source: UInt128) {
self.init(upperBits: source.value.upperBits,
lowerBits: source.value.lowerBits)
}
/// Initialize a UInt128 value from a string.
///
/// - parameter source: the string that will be converted into a
/// UInt128 value. Defaults to being analyzed as a base10 number,
/// but can be prefixed with `0b` for base2, `0o` for base8
/// or `0x` for base16.
public init(_ source: String) throws {
guard let result = UInt128._valueFromString(source) else {
throw UInt128Errors.invalidString
}
self = result
}
}
// MARK: - FixedWidthInteger Conformance
extension UInt128 : FixedWidthInteger {
public static var bitWidth : Int { return 128 }
// MARK: Instance Properties
public var nonzeroBitCount: Int {
var nonZeroCount = 0
var shiftWidth = 0
while shiftWidth < 128 {
let shiftedSelf = self &>> shiftWidth
let currentBit = shiftedSelf & 1
if currentBit == 1 {
nonZeroCount += 1
}
shiftWidth += 1
}
return nonZeroCount
}
public var leadingZeroBitCount: Int {
var zeroCount = 0
var shiftWidth = 127
while shiftWidth >= 0 {
let currentBit = self &>> shiftWidth
guard currentBit == 0 else { break }
zeroCount += 1
shiftWidth -= 1
}
return zeroCount
}
/// Returns the big-endian representation of the integer, changing the byte order if necessary.
public var bigEndian: UInt128 {
#if arch(i386) || arch(x86_64) || arch(arm) || arch(arm64)
return self.byteSwapped
#else
return self
#endif
}
/// Returns the little-endian representation of the integer, changing the byte order if necessary.
public var littleEndian: UInt128 {
#if arch(i386) || arch(x86_64) || arch(arm) || arch(arm64)
return self
#else
return self.byteSwapped
#endif
}
/// Returns the current integer with the byte order swapped.
public var byteSwapped: UInt128 {
return UInt128(upperBits: self.value.lowerBits.byteSwapped, lowerBits: self.value.upperBits.byteSwapped)
}
// MARK: Initializers
/// Creates a UInt128 from a given value, with the input's value
/// truncated to a size no larger than what UInt128 can handle.
/// Since the input is constrained to an UInt, no truncation needs
/// to occur, as a UInt is currently 64 bits at the maximum.
public init(_truncatingBits bits: UInt) {
self.init(upperBits: 0, lowerBits: UInt64(bits))
}
/// Creates an integer from its big-endian representation, changing the
/// byte order if necessary.
public init(bigEndian value: UInt128) {
#if arch(i386) || arch(x86_64) || arch(arm) || arch(arm64)
self = value.byteSwapped
#else
self = value
#endif
}
/// Creates an integer from its little-endian representation, changing the
/// byte order if necessary.
public init(littleEndian value: UInt128) {
#if arch(i386) || arch(x86_64) || arch(arm) || arch(arm64)
self = value
#else
self = value.byteSwapped
#endif
}
// MARK: Instance Methods
public func addingReportingOverflow(_ rhs: UInt128) -> (partialValue: UInt128, overflow: Bool) {
var resultOverflow = false
let (lowerBits, lowerOverflow) = self.value.lowerBits.addingReportingOverflow(rhs.value.lowerBits)
var (upperBits, upperOverflow) = self.value.upperBits.addingReportingOverflow(rhs.value.upperBits)
// If the lower bits overflowed, we need to add 1 to upper bits.
if lowerOverflow {
(upperBits, resultOverflow) = upperBits.addingReportingOverflow(1)
}
return (partialValue: UInt128(upperBits: upperBits, lowerBits: lowerBits),
overflow: upperOverflow || resultOverflow)
}
public func subtractingReportingOverflow(_ rhs: UInt128) -> (partialValue: UInt128, overflow: Bool) {
var resultOverflow = false
let (lowerBits, lowerOverflow) = self.value.lowerBits.subtractingReportingOverflow(rhs.value.lowerBits)
var (upperBits, upperOverflow) = self.value.upperBits.subtractingReportingOverflow(rhs.value.upperBits)
// If the lower bits overflowed, we need to subtract (borrow) 1 from the upper bits.
if lowerOverflow {
(upperBits, resultOverflow) = upperBits.subtractingReportingOverflow(1)
}
return (partialValue: UInt128(upperBits: upperBits, lowerBits: lowerBits),
overflow: upperOverflow || resultOverflow)
}
public func multipliedReportingOverflow(by rhs: UInt128) -> (partialValue: UInt128, overflow: Bool) {
let multiplicationResult = self.multipliedFullWidth(by: rhs)
let overflowEncountered = multiplicationResult.high > 0
return (partialValue: multiplicationResult.low,
overflow: overflowEncountered)
}
public func multipliedFullWidth(by other: UInt128) -> (high: UInt128, low: UInt128.Magnitude) {
// Bit mask that facilitates masking the lower 32 bits of a 64 bit UInt.
let lower32 = UInt64(UInt32.max)
// Decompose lhs into an array of 4, 32 significant bit UInt64s.
let lhsArray = [
self.value.upperBits >> 32, /*0*/ self.value.upperBits & lower32, /*1*/
self.value.lowerBits >> 32, /*2*/ self.value.lowerBits & lower32 /*3*/
]
// Decompose rhs into an array of 4, 32 significant bit UInt64s.
let rhsArray = [
other.value.upperBits >> 32, /*0*/ other.value.upperBits & lower32, /*1*/
other.value.lowerBits >> 32, /*2*/ other.value.lowerBits & lower32 /*3*/
]
// The future contents of this array will be used to store segment
// multiplication results.
var resultArray = [[UInt64]].init(
repeating: [UInt64].init(repeating: 0, count: 4), count: 4
)
// Loop through every combination of lhsArray[x] * rhsArray[y]
for rhsSegment in 0 ..< rhsArray.count {
for lhsSegment in 0 ..< lhsArray.count {
let currentValue = lhsArray[lhsSegment] * rhsArray[rhsSegment]
resultArray[lhsSegment][rhsSegment] = currentValue
}
}
// Perform multiplication similar to pen and paper in 64bit, 32bit masked increments.
let bitSegment8 = resultArray[3][3] & lower32
let bitSegment7 = UInt128._variadicAdditionWithOverflowCount(
resultArray[2][3] & lower32,
resultArray[3][2] & lower32,
resultArray[3][3] >> 32) // overflow from bitSegment8
let bitSegment6 = UInt128._variadicAdditionWithOverflowCount(
resultArray[1][3] & lower32,
resultArray[2][2] & lower32,
resultArray[3][1] & lower32,
resultArray[2][3] >> 32, // overflow from bitSegment7
resultArray[3][2] >> 32, // overflow from bitSegment7
bitSegment7.overflowCount)
let bitSegment5 = UInt128._variadicAdditionWithOverflowCount(
resultArray[0][3] & lower32,
resultArray[1][2] & lower32,
resultArray[2][1] & lower32,
resultArray[3][0] & lower32,
resultArray[1][3] >> 32, // overflow from bitSegment6
resultArray[2][2] >> 32, // overflow from bitSegment6
resultArray[3][1] >> 32, // overflow from bitSegment6
bitSegment6.overflowCount)
let bitSegment4 = UInt128._variadicAdditionWithOverflowCount(
resultArray[0][2] & lower32,
resultArray[1][1] & lower32,
resultArray[2][0] & lower32,
resultArray[0][3] >> 32, // overflow from bitSegment5
resultArray[1][2] >> 32, // overflow from bitSegment5
resultArray[2][1] >> 32, // overflow from bitSegment5
resultArray[3][0] >> 32, // overflow from bitSegment5
bitSegment5.overflowCount)
let bitSegment3 = UInt128._variadicAdditionWithOverflowCount(
resultArray[0][1] & lower32,
resultArray[1][0] & lower32,
resultArray[0][2] >> 32, // overflow from bitSegment4
resultArray[1][1] >> 32, // overflow from bitSegment4
resultArray[2][0] >> 32, // overflow from bitSegment4
bitSegment4.overflowCount)
let bitSegment1 = UInt128._variadicAdditionWithOverflowCount(
resultArray[0][0],
resultArray[0][1] >> 32, // overflow from bitSegment3
resultArray[1][0] >> 32, // overflow from bitSegment3
bitSegment3.overflowCount)
// Shift and merge the results into 64 bit groups, adding in overflows as we go.
let lowerLowerBits = UInt128._variadicAdditionWithOverflowCount(
bitSegment8,
bitSegment7.truncatedValue << 32)
let upperLowerBits = UInt128._variadicAdditionWithOverflowCount(
bitSegment7.truncatedValue >> 32,
bitSegment6.truncatedValue,
bitSegment5.truncatedValue << 32,
lowerLowerBits.overflowCount)
let lowerUpperBits = UInt128._variadicAdditionWithOverflowCount(
bitSegment5.truncatedValue >> 32,
bitSegment4.truncatedValue,
bitSegment3.truncatedValue << 32,
upperLowerBits.overflowCount)
let upperUpperBits = UInt128._variadicAdditionWithOverflowCount(
bitSegment3.truncatedValue >> 32,
bitSegment1.truncatedValue,
lowerUpperBits.overflowCount)
// Bring the 64bit unsigned integer results together into a high and low 128bit unsigned integer result.
return (high: UInt128(upperBits: upperUpperBits.truncatedValue, lowerBits: lowerUpperBits.truncatedValue),
low: UInt128(upperBits: upperLowerBits.truncatedValue, lowerBits: lowerLowerBits.truncatedValue))
}
/// Takes a variable amount of 64bit Unsigned Integers and adds them together,
/// tracking the total amount of overflows that occurred during addition.
///
/// - Parameter addends:
/// Variably sized list of UInt64 values.
/// - Returns:
/// A tuple containing the truncated result and a count of the total
/// amount of overflows that occurred during addition.
private static func _variadicAdditionWithOverflowCount(_ addends: UInt64...) -> (truncatedValue: UInt64, overflowCount: UInt64) {
var sum: UInt64 = 0
var overflowCount: UInt64 = 0
addends.forEach { addend in
let interimSum = sum.addingReportingOverflow(addend)
if interimSum.overflow {
overflowCount += 1
}
sum = interimSum.partialValue
}
return (truncatedValue: sum, overflowCount: overflowCount)
}
public func dividedReportingOverflow(by rhs: UInt128) -> (partialValue: UInt128, overflow: Bool) {
guard rhs != 0 else {
return (self, true)
}
let quotient = self.quotientAndRemainder(dividingBy: rhs).quotient
return (quotient, false)
}
public func dividingFullWidth(_ dividend: (high: UInt128, low: UInt128)) -> (quotient: UInt128, remainder: UInt128) {
return self._quotientAndRemainderFullWidth(dividingBy: dividend)
}
public func remainderReportingOverflow(dividingBy rhs: UInt128) -> (partialValue: UInt128, overflow: Bool) {
guard rhs != 0 else {
return (self, true)
}
let remainder = self.quotientAndRemainder(dividingBy: rhs).remainder
return (remainder, false)
}
public func quotientAndRemainder(dividingBy rhs: UInt128) -> (quotient: UInt128, remainder: UInt128) {
return rhs._quotientAndRemainderFullWidth(dividingBy: (high: 0, low: self))
}
/// Provides the quotient and remainder when dividing the provided value by self.
internal func _quotientAndRemainderFullWidth(dividingBy dividend: (high: UInt128, low: UInt128)) -> (quotient: UInt128, remainder: UInt128) {
let divisor = self
let numeratorBitsToWalk: UInt128
if dividend.high > 0 {
numeratorBitsToWalk = dividend.high.significantBits + 128 - 1
} else if dividend.low == 0 {
return (0, 0)
} else {
numeratorBitsToWalk = dividend.low.significantBits - 1
}
// The below algorithm was adapted from:
// https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_.28unsigned.29_with_remainder
precondition(self != 0, "Division by 0")
var quotient = UInt128.min
var remainder = UInt128.min
for numeratorShiftWidth in (0...numeratorBitsToWalk).reversed() {
remainder <<= 1
remainder |= UInt128._bitFromDoubleWidth(at: numeratorShiftWidth, for: dividend)
if remainder >= divisor {
remainder -= divisor
quotient |= 1 << numeratorShiftWidth
}
}
return (quotient, remainder)
}
/// Returns the bit stored at the given position for the provided double width UInt128 input.
///
/// - parameter at: position to grab bit value from.
/// - parameter for: the double width UInt128 data value to grab the
/// bit from.
/// - returns: single bit stored in a UInt128 value.
internal static func _bitFromDoubleWidth(at bitPosition: UInt128, for input: (high: UInt128, low: UInt128)) -> UInt128 {
switch bitPosition {
case 0:
return input.low & 1
case 1...127:
return input.low >> bitPosition & 1
case 128:
return input.high & 1
default:
return input.high >> (bitPosition - 128) & 1
}
}
}
// MARK: - BinaryInteger Conformance
extension UInt128 : BinaryInteger {
// MARK: Instance Properties
public var bitWidth : Int { return 128 }
// MARK: Instance Methods
public var words: [UInt] {
guard self != UInt128.min else {
return []
}
var words: [UInt] = []
for currentWord in 0 ... self.bitWidth / UInt.bitWidth {
let shiftAmount: UInt64 = UInt64(UInt.bitWidth) * UInt64(currentWord)
let mask = UInt64(UInt.max)
var shifted = self
if shiftAmount > 0 {
shifted &>>= UInt128(upperBits: 0, lowerBits: shiftAmount)
}
let masked: UInt128 = shifted & UInt128(upperBits: 0, lowerBits: mask)
words.append(UInt(masked.value.lowerBits))
}
return words
}
public var trailingZeroBitCount: Int {
let mask: UInt128 = 1
var bitsToWalk = self
for currentPosition in 0...128 {
if bitsToWalk & mask == 1 {
return currentPosition
}
bitsToWalk >>= 1
}
return 128
}
// MARK: Initializers
public init?<T : BinaryFloatingPoint>(exactly source: T) {
if source.isZero {
self = UInt128()
}
else if source.exponent < 0 || source.rounded() != source {
return nil
}
else {
self = UInt128(UInt64(source))
}
}
public init<T : BinaryFloatingPoint>(_ source: T) {
self.init(UInt64(source))
}
// MARK: Type Methods
public static func /(_ lhs: UInt128, _ rhs: UInt128) -> UInt128 {
let result = lhs.dividedReportingOverflow(by: rhs)
return result.partialValue
}
public static func /=(_ lhs: inout UInt128, _ rhs: UInt128) {
lhs = lhs / rhs
}
public static func %(_ lhs: UInt128, _ rhs: UInt128) -> UInt128 {
let result = lhs.remainderReportingOverflow(dividingBy: rhs)
return result.partialValue
}
public static func %=(_ lhs: inout UInt128, _ rhs: UInt128) {
lhs = lhs % rhs
}
/// Performs a bitwise AND operation on 2 UInt128 data types.
public static func &=(_ lhs: inout UInt128, _ rhs: UInt128) {
let upperBits = lhs.value.upperBits & rhs.value.upperBits
let lowerBits = lhs.value.lowerBits & rhs.value.lowerBits
lhs = UInt128(upperBits: upperBits, lowerBits: lowerBits)
}
/// Performs a bitwise OR operation on 2 UInt128 data types.
public static func |=(_ lhs: inout UInt128, _ rhs: UInt128) {
let upperBits = lhs.value.upperBits | rhs.value.upperBits
let lowerBits = lhs.value.lowerBits | rhs.value.lowerBits
lhs = UInt128(upperBits: upperBits, lowerBits: lowerBits)
}
/// Performs a bitwise XOR operation on 2 UInt128 data types.
public static func ^=(_ lhs: inout UInt128, _ rhs: UInt128) {
let upperBits = lhs.value.upperBits ^ rhs.value.upperBits
let lowerBits = lhs.value.lowerBits ^ rhs.value.lowerBits
lhs = UInt128(upperBits: upperBits, lowerBits: lowerBits)
}
/// Perform a masked right SHIFT operation self.
///
/// The masking operation will mask `rhs` against the highest
/// shift value that will not cause an overflowing shift before
/// performing the shift. IE: `rhs = 128` will become `rhs = 0`
/// and `rhs = 129` will become `rhs = 1`.
public static func &>>=(_ lhs: inout UInt128, _ rhs: UInt128) {
let shiftWidth = rhs.value.lowerBits & 127
switch shiftWidth {
case 0: return // Do nothing shift.
case 1...63:
let upperBits = lhs.value.upperBits >> shiftWidth
let lowerBits = (lhs.value.lowerBits >> shiftWidth) + (lhs.value.upperBits << (64 - shiftWidth))
lhs = UInt128(upperBits: upperBits, lowerBits: lowerBits)
case 64:
// Shift 64 means move upper bits to lower bits.
lhs = UInt128(upperBits: 0, lowerBits: lhs.value.upperBits)
default:
let lowerBits = lhs.value.upperBits >> (shiftWidth - 64)
lhs = UInt128(upperBits: 0, lowerBits: lowerBits)
}
}
/// Perform a masked left SHIFT operation on self.
///
/// The masking operation will mask `rhs` against the highest
/// shift value that will not cause an overflowing shift before
/// performing the shift. IE: `rhs = 128` will become `rhs = 0`
/// and `rhs = 129` will become `rhs = 1`.
public static func &<<=(_ lhs: inout UInt128, _ rhs: UInt128) {
let shiftWidth = rhs.value.lowerBits & 127
switch shiftWidth {
case 0: return // Do nothing shift.
case 1...63:
let upperBits = (lhs.value.upperBits << shiftWidth) + (lhs.value.lowerBits >> (64 - shiftWidth))
let lowerBits = lhs.value.lowerBits << shiftWidth
lhs = UInt128(upperBits: upperBits, lowerBits: lowerBits)
case 64:
// Shift 64 means move lower bits to upper bits.
lhs = UInt128(upperBits: lhs.value.lowerBits, lowerBits: 0)
default:
let upperBits = lhs.value.lowerBits << (shiftWidth - 64)
lhs = UInt128(upperBits: upperBits, lowerBits: 0)
}
}
}
// MARK: - UnsignedInteger Conformance
extension UInt128 : UnsignedInteger {}
// MARK: - Hashable Conformance
extension UInt128 : Hashable {
public func hash(into hasher: inout Hasher) {
hasher.combine(self.value.lowerBits)
hasher.combine(self.value.upperBits)
}
}
// MARK: - Numeric Conformance
extension UInt128 : Numeric {
public static func +(_ lhs: UInt128, _ rhs: UInt128) -> UInt128 {
precondition(~lhs >= rhs, "Addition overflow!")
let result = lhs.addingReportingOverflow(rhs)
return result.partialValue
}
public static func +=(_ lhs: inout UInt128, _ rhs: UInt128) {
lhs = lhs + rhs
}
public static func -(_ lhs: UInt128, _ rhs: UInt128) -> UInt128 {
precondition(lhs >= rhs, "Integer underflow")
let result = lhs.subtractingReportingOverflow(rhs)
return result.partialValue
}
public static func -=(_ lhs: inout UInt128, _ rhs: UInt128) {
lhs = lhs - rhs
}
public static func *(_ lhs: UInt128, _ rhs: UInt128) -> UInt128 {
let result = lhs.multipliedReportingOverflow(by: rhs)
precondition(!result.overflow, "Multiplication overflow!")
return result.partialValue
}
public static func *=(_ lhs: inout UInt128, _ rhs: UInt128) {
lhs = lhs * rhs
}
}
// MARK: - Equatable Conformance
extension UInt128 : Equatable {
/// Checks if the `lhs` is equal to the `rhs`.
public static func ==(lhs: UInt128, rhs: UInt128) -> Bool {
if lhs.value.lowerBits == rhs.value.lowerBits && lhs.value.upperBits == rhs.value.upperBits {
return true
}
return false
}
}
// MARK: - ExpressibleByIntegerLiteral Conformance
extension UInt128 : ExpressibleByIntegerLiteral {
public init(integerLiteral value: IntegerLiteralType) {
self.init(upperBits: 0, lowerBits: UInt64(value))
}
}
// MARK: - CustomStringConvertible Conformance
extension UInt128 : CustomStringConvertible {
// MARK: Instance Properties
public var description: String {
return self._valueToString()
}
// MARK: Instance Methods
/// Converts the stored value into a string representation.
/// - parameter radix:
/// The radix for the base numbering system you wish to have
/// the type presented in.
/// - parameter uppercase:
/// Determines whether letter components of the outputted string will be in
/// uppercase format or not.
/// - returns:
/// String representation of the stored UInt128 value.
internal func _valueToString(radix: Int = 10, uppercase: Bool = true) -> String {
precondition(radix > 1 && radix < 37, "radix must be within the range of 2-36.")
// Will store the final string result.
var result = String()
// Simple case.
if self == 0 {
result.append("0")
return result
}
// Used as the check for indexing through UInt128 for string interpolation.
var divmodResult = (quotient: self, remainder: UInt128(0))
// Will hold the pool of possible values.
let characterPool = (uppercase) ? "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" : "0123456789abcdefghijklmnopqrstuvwxyz"
// Go through internal value until every base position is string(ed).
repeat {
divmodResult = divmodResult.quotient.quotientAndRemainder(dividingBy: UInt128(radix))
let index = characterPool.index(characterPool.startIndex, offsetBy: Int(divmodResult.remainder))
result.insert(characterPool[index], at: result.startIndex)
} while divmodResult.quotient > 0
return result
}
}
// MARK: - CustomDebugStringConvertible Conformance
extension UInt128 : CustomDebugStringConvertible {
public var debugDescription: String {
return self.description
}
}
// MARK: - Comparable Conformance
extension UInt128 : Comparable {
public static func <(lhs: UInt128, rhs: UInt128) -> Bool {
if lhs.value.upperBits < rhs.value.upperBits {
return true
} else if lhs.value.upperBits == rhs.value.upperBits && lhs.value.lowerBits < rhs.value.lowerBits {
return true
}
return false
}
}
// MARK: - ExpressibleByStringLiteral Conformance
extension UInt128 : ExpressibleByStringLiteral {
// MARK: Initializers
public init(stringLiteral value: StringLiteralType) {
self.init()
if let result = UInt128._valueFromString(value) {
self = result
}
}
// MARK: Type Methods
internal static func _valueFromString(_ value: String) -> UInt128? {
let radix = UInt128._determineRadixFromString(value)
let inputString = radix == 10 ? value : String(value.dropFirst(2))
return UInt128(inputString, radix: radix)
}
internal static func _determineRadixFromString(_ string: String) -> Int {
let radix: Int
if string.hasPrefix("0b") { radix = 2 }
else if string.hasPrefix("0o") { radix = 8 }
else if string.hasPrefix("0x") { radix = 16 }
else { radix = 10 }
return radix
}
}
// MARK: - Deprecated API
extension UInt128 {
/// Initialize a UInt128 value from a string.
///
/// - parameter source: the string that will be converted into a
/// UInt128 value. Defaults to being analyzed as a base10 number,
/// but can be prefixed with `0b` for base2, `0o` for base8
/// or `0x` for base16.
@available(swift, deprecated: 3.2, renamed: "init(_:)")
public static func fromUnparsedString(_ source: String) throws -> UInt128 {
return try UInt128.init(source)
}
}
// MARK: - BinaryFloatingPoint Interworking
extension BinaryFloatingPoint {
public init(_ value: UInt128) {
precondition(value.value.upperBits == 0, "Value is too large to fit into a BinaryFloatingPoint until a 128bit BinaryFloatingPoint type is defined.")
self.init(value.value.lowerBits)
}
public init?(exactly value: UInt128) {
if value.value.upperBits > 0 {
return nil
}
self = Self(value.value.lowerBits)
}
}
// MARK: - String Interworking
extension String {
/// Creates a string representing the given value in base 10, or some other
/// specified base.
///
/// - Parameters:
/// - value: The UInt128 value to convert to a string.
/// - radix: The base to use for the string representation. `radix` must be
/// at least 2 and at most 36. The default is 10.
/// - uppercase: Pass `true` to use uppercase letters to represent numerals
/// or `false` to use lowercase letters. The default is `false`.
public init(_ value: UInt128, radix: Int = 10, uppercase: Bool = false) {
self = value._valueToString(radix: radix, uppercase: uppercase)
}
}

View File

@@ -0,0 +1,161 @@
import Foundation
import dnssd
private let dict = SafeDict<Resolver>()
public enum ResolveType: DNSServiceProtocol {
case ipv4 = 1, ipv6 = 2, any = 3
}
public class Resolver {
public static var queue: DispatchQueue {
get {
return _queue
}
set {
_queue.setSpecific(key: queueKey, value: "")
_queue = newValue
_queue.setSpecific(key: queueKey, value: "ResolverQueue")
}
}
fileprivate static let queueKey = DispatchSpecificKey<String>()
private static var _queue = {
return DispatchQueue(label: "ResolverQueue")
}()
public static var activeCount: Int {
return dict.count
}
public let hostname: String
fileprivate let resolveType: ResolveType
fileprivate let firstResult: Bool
public var ipv4Result: [String] = []
public var ipv6Result: [String] = []
public var result: [String] {
return ipv4Result + ipv6Result
}
var cancelled = false
fileprivate var ref: DNSServiceRef?
fileprivate var id: UnsafeMutablePointer<Int>?
fileprivate var completionHandler: ((Resolver?, DNSServiceErrorType?)->())!
fileprivate let timeout: Int
fileprivate let timer = DispatchSource.makeTimerSource(queue: Resolver.queue)
public static func resolve(hostname: String, qtype: ResolveType = .ipv4, firstResult: Bool = true, timeout: Int = 3, completionHanlder: @escaping (Resolver?, DNSServiceErrorType?)->()) -> Bool {
let resolver = Resolver(hostname: hostname, qtype: qtype, firstResult: firstResult, timeout: timeout)
resolver.completionHandler = completionHanlder
return resolver.resolve()
}
fileprivate init(hostname: String, qtype: ResolveType, firstResult: Bool, timeout: Int) {
self.hostname = hostname
self.resolveType = qtype
self.firstResult = firstResult
self.timeout = timeout
}
fileprivate func resolve() -> Bool {
guard ref == nil else {
return false
}
var result: Bool = false
let action = DispatchWorkItem {
self.id = dict.insert(value: self)
self.timer.schedule(deadline: DispatchTime.now() + DispatchTimeInterval.seconds(self.timeout))
self.timer.setEventHandler(handler: self.timeoutHandler)
result = self.hostname.withCString { (ptr: UnsafePointer<Int8>) in
guard DNSServiceGetAddrInfo(&self.ref, 0, 0, self.resolveType.rawValue, self.hostname, { (sdRef, flags, interfaceIndex, errorCode, ptr, address, ttl, context) in
// Note this callback block will be called on `Resolver.queue`.
guard let resolver = dict.get(context!.bindMemory(to: Int.self, capacity: 1)) else {
NSLog("Error: Got some unknown resolver.")
return
}
guard !resolver.cancelled else {
return
}
guard errorCode == DNSServiceErrorType(kDNSServiceErr_NoError) else {
resolver.release()
resolver.completionHandler(nil, errorCode)
return
}
switch (Int32(address!.pointee.sa_family)) {
case AF_INET:
var buffer = [Int8](repeating: 0, count: Int(INET_ADDRSTRLEN))
_ = buffer.withUnsafeMutableBufferPointer { buf in
address?.withMemoryRebound(to: sockaddr_in.self, capacity: 1) { addr in
var sin_addr = addr.pointee.sin_addr
inet_ntop(AF_INET, &sin_addr, buf.baseAddress, socklen_t(INET_ADDRSTRLEN))
let addr = String(cString: buf.baseAddress!)
resolver.ipv4Result.append(addr)
}
}
case AF_INET6:
var buffer = [Int8](repeating: 0, count: Int(INET6_ADDRSTRLEN))
_ = buffer.withUnsafeMutableBufferPointer { buf in
address?.withMemoryRebound(to: sockaddr_in6.self, capacity: 1) { addr in
var sin6_addr = addr.pointee.sin6_addr
inet_ntop(AF_INET6, &sin6_addr, buf.baseAddress, socklen_t(INET6_ADDRSTRLEN))
let addr = String(cString: buf.baseAddress!)
resolver.ipv6Result.append(addr)
}
}
default:
break
}
if (resolver.firstResult || flags & DNSServiceFlags(kDNSServiceFlagsMoreComing) == 0) {
resolver.release()
return resolver.completionHandler(resolver, nil)
}
}, self.id) == DNSServiceErrorType(kDNSServiceErr_NoError) else {
return false
}
DNSServiceSetDispatchQueue(self.ref, Resolver.queue)
self.timer.resume()
return true
}
}
if DispatchQueue.getSpecific(key: Resolver.queueKey) == "ResolverQueue" {
action.perform()
} else {
Resolver.queue.sync(execute: action)
}
return result
}
func timeoutHandler() {
if !cancelled {
release()
completionHandler(nil, DNSServiceErrorType(kDNSServiceErr_Timeout))
}
}
func release() {
cancelled = true
timer.cancel()
if ref != nil {
DNSServiceRefDeallocate(ref)
ref = nil
}
if id != nil {
_ = dict.remove(id!)
id = nil
}
}
}

View File

@@ -0,0 +1,40 @@
import Foundation
/// This class is not thread-safe.
class SafeDict<T> {
private var dict: [Int:T] = [:]
private var curr = 0
var count: Int {
return dict.count
}
func insert(value: T) -> UnsafeMutablePointer<Int> {
let ptr = UnsafeMutablePointer<Int>.allocate(capacity: 1)
ptr.pointee = curr
dict[curr] = value
curr += 1
return ptr
}
func get(_ id: Int) -> T? {
return dict[id]
}
func get(_ id: UnsafePointer<Int>) -> T? {
return get(id.pointee)
}
func remove(_ id: Int) -> T? {
return dict.removeValue(forKey: id)
}
func remove(_ id: UnsafeMutablePointer<Int>) -> T? {
defer {
id.deallocate()
}
return remove(id.pointee)
}
}