337 lines
9.8 KiB
Swift
Executable File
337 lines
9.8 KiB
Swift
Executable File
import Foundation
|
|
import NetworkExtension
|
|
|
|
/// The TCP socket built upon `NWTCPConnection`.
///
/// - warning: This class is not thread-safe.
public class NWTCPSocket: NSObject, RawTCPSocketProtocol {
    /// The underlying connection; `nil` until `connectTo(host:port:enableTLS:tlsSettings:)` obtains one.
    private var connection: NWTCPConnection?

    /// `true` while a write started by `send(data:)` has not completed yet.
    private var writePending = false

    /// Set by `disconnect()`; the connection is cancelled as soon as no write is pending.
    private var closeAfterWriting = false

    /// Once `true`, new reads and writes are ignored.
    private var cancelled = false

    /// Pattern scanner used while a `readDataTo(data:maxLength:)` request is in progress.
    private var scanner: StreamScanner!

    /// Whether incoming data is currently routed through `scanner`.
    private var scanning: Bool = false

    /// Bytes received after a matched pattern; they are prepended to the next chunk that is read.
    private var readDataPrefix: Data?

    // MARK: RawTCPSocketProtocol implementation

    /// The `RawTCPSocketDelegate` instance.
    weak open var delegate: RawTCPSocketDelegate?

    /// If the socket is connected.
    public var isConnected: Bool {
        return connection?.state == .connected
    }

    /// The source address.
    ///
    /// - note: Always returns `nil`.
    public var sourceIPAddress: IPAddress? {
        return nil
    }

    /// The source port.
    ///
    /// - note: Always returns `nil`.
    public var sourcePort: Port? {
        return nil
    }

    /// The destination address.
    ///
    /// - note: Always returns `nil`.
    public var destinationIPAddress: IPAddress? {
        return nil
    }

    /// The destination port.
    ///
    /// - note: Always returns `nil`.
    public var destinationPort: Port? {
        return nil
    }

    /**
     Connect to remote host.

     If `RawSocketFactory.TunnelProvider` is `nil` no connection is created and the method returns silently.

     - parameter host:        Remote host.
     - parameter port:        Remote port.
     - parameter enableTLS:   Should TLS be enabled.
     - parameter tlsSettings: The settings of TLS. Entries are applied to `NWTLSParameters` via
                              key-value coding, so keys must name properties that class recognizes.

     - throws: Never throws.
     */
    public func connectTo(host: String, port: Int, enableTLS: Bool, tlsSettings: [AnyHashable: Any]?) throws {
        let endpoint = NWHostEndpoint(hostname: host, port: "\(port)")
        let tlsParameters = NWTLSParameters()
        if let tlsSettings = tlsSettings as? [String: AnyObject] {
            tlsParameters.setValuesForKeys(tlsSettings)
        }

        guard let connection = RawSocketFactory.TunnelProvider?.createTCPConnection(to: endpoint, enableTLS: enableTLS, tlsParameters: tlsParameters, delegate: nil) else {
            // This should only happen when the extension is already stopped and `RawSocketFactory.TunnelProvider` is set to `nil`.
            return
        }

        // `deinit` only unregisters from the current connection, so drop the observation of any
        // previous connection before replacing it; otherwise that registration would leak.
        self.connection?.removeObserver(self, forKeyPath: "state")

        // Assign before observing: `.initial` fires `observeValue` synchronously, and it reads `self.connection`.
        self.connection = connection
        connection.addObserver(self, forKeyPath: "state", options: [.initial, .new], context: nil)
    }

    /**
     Disconnect the socket.

     The socket disconnects gracefully: the connection is cancelled only after any in-flight write has completed.
     If there is no connection, or it is already cancelled, the delegate is told about the disconnection right away.
     */
    public func disconnect() {
        cancelled = true

        if hasLiveConnection {
            closeAfterWriting = true
            checkStatus()
        } else {
            delegate?.didDisconnectWith(socket: self)
        }
    }

    /**
     Disconnect the socket immediately, without waiting for an in-flight write.
     */
    public func forceDisconnect() {
        cancelled = true

        if hasLiveConnection {
            cancel()
        } else {
            delegate?.didDisconnectWith(socket: self)
        }
    }

    /**
     Send data to remote.

     Empty data is not sent; `didWrite(data:by:)` is still reported asynchronously.

     - parameter data: Data to send.
     - warning: This should only be called after the last write is finished, i.e., `delegate?.didWrite(data:by:)` is called.
     */
    public func write(data: Data) {
        guard !cancelled else {
            return
        }

        guard !data.isEmpty else {
            queueCall {
                self.delegate?.didWrite(data: data, by: self)
            }
            return
        }

        send(data: data)
    }

    /**
     Read data from the socket.

     - warning: This should only be called after the last read is finished, i.e., `delegate?.didRead(data:from:)` is called.
     */
    public func readData() {
        guard !cancelled else {
            return
        }

        guard let connection = connectionOrDisconnect() else {
            return
        }

        connection.readMinimumLength(1, maximumLength: Opt.MAXNWTCPSocketReadDataSize) { data, error in
            guard error == nil else {
                self.handleReadError(error)
                return
            }

            self.readCallback(data: data)
        }
    }

    /**
     Read specific length of data from the socket.

     - parameter length: The length of the data to read.
     - warning: This should only be called after the last read is finished, i.e., `delegate?.didRead(data:from:)` is called.
     */
    public func readDataTo(length: Int) {
        guard !cancelled else {
            return
        }

        guard let connection = connectionOrDisconnect() else {
            return
        }

        connection.readLength(length) { data, error in
            guard error == nil else {
                self.handleReadError(error)
                return
            }

            self.readCallback(data: data)
        }
    }

    /**
     Read data until a specific pattern (including the pattern).

     Uses the default scan limit `Opt.MAXNWTCPScanLength`.

     - parameter data: The pattern.
     - warning: This should only be called after the last read is finished, i.e., `delegate?.didRead(data:from:)` is called.
     */
    public func readDataTo(data: Data) {
        readDataTo(data: data, maxLength: 0)
    }

    // Actually, this method is available as `- (void)readToPattern:(id)arg1 maximumLength:(unsigned int)arg2 completionHandler:(id /* block */)arg3;`
    // which is sadly not available in the public header for some reason.
    // This method is not trivial to implement and I don't like reinventing the wheel,
    // so here is only the most naive version, which may not be optimal when used with large data blocks.
    /**
     Read data until a specific pattern (including the pattern).

     - parameter data:      The pattern.
     - parameter maxLength: The max length of data to scan for the pattern; `0` means `Opt.MAXNWTCPScanLength`.
     - warning: This should only be called after the last read is finished, i.e., `delegate?.didRead(data:from:)` is called.
     */
    public func readDataTo(data: Data, maxLength: Int) {
        guard !cancelled else {
            return
        }

        let scanLimit = maxLength == 0 ? Opt.MAXNWTCPScanLength : maxLength
        scanner = StreamScanner(pattern: data, maximumLength: scanLimit)
        scanning = true
        readData()
    }

    /// Runs `block` asynchronously on the queue returned by `QueueFactory.getQueue()`.
    private func queueCall(_ block: @escaping () -> Void) {
        QueueFactory.getQueue().async(execute: block)
    }

    /// KVO callback for the connection's `state`.
    ///
    /// - `.connected`: reports `didConnectWith(socket:)` on the socket queue.
    /// - `.disconnected`: marks the socket cancelled and cancels the connection.
    /// - `.cancelled`: marks the socket cancelled, detaches the delegate, then reports `didDisconnectWith(socket:)`.
    override public func observeValue(forKeyPath keyPath: String?, of object: Any?, change: [NSKeyValueChangeKey : Any]?, context: UnsafeMutableRawPointer?) {
        guard keyPath == "state", let connection = connection else {
            return
        }

        switch connection.state {
        case .connected:
            queueCall {
                self.delegate?.didConnectWith(socket: self)
            }
        case .disconnected:
            cancelled = true
            cancel()
        case .cancelled:
            cancelled = true
            queueCall {
                // Detach the delegate before notifying it so later `self.delegate?` calls become no-ops.
                let delegate = self.delegate
                self.delegate = nil
                delegate?.didDisconnectWith(socket: self)
            }
        default:
            break
        }
    }

    /// Handles the result of a completed read and forwards it to the delegate on the socket queue.
    ///
    /// Leftover bytes from a previous pattern scan (`readDataPrefix`) are prepended first. While scanning,
    /// the data is fed to `scanner`: if it has no result yet another read is issued; once it produces one,
    /// scanning ends, the bytes after the match are kept for the next read and only the match is delivered.
    private func readCallback(data: Data?) {
        guard !cancelled else {
            return
        }

        queueCall {
            guard let data = self.consumeReadData(data) else {
                // remote read is closed, but this is okay, nothing need to be done, if this socket is read again, then error occurs.
                return
            }

            guard self.scanning else {
                self.delegate?.didRead(data: data, from: self)
                return
            }

            guard let (match, rest) = self.scanner.addAndScan(data) else {
                self.readData()
                return
            }

            self.scanner = nil
            self.scanning = false

            guard let matchData = match else {
                // do not find match in the given length, stop now
                // NOTE(review): no delegate callback is made and `rest` is discarded here — confirm callers expect this.
                return
            }

            self.readDataPrefix = rest
            self.delegate?.didRead(data: matchData, from: self)
        }
    }

    /// Writes `data` to the connection; on completion (on the socket queue) reports `didWrite(data:by:)`
    /// and applies a pending graceful close, or logs the error and disconnects.
    private func send(data: Data) {
        guard let connection = connectionOrDisconnect() else {
            return
        }

        writePending = true
        connection.write(data) { error in
            self.queueCall {
                self.writePending = false

                guard error == nil else {
                    DDLogError("NWTCPSocket got an error when writing data: \(String(describing: error))")
                    self.disconnect()
                    return
                }

                self.delegate?.didWrite(data: data, by: self)
                self.checkStatus()
            }
        }
    }

    /// Returns `readDataPrefix` followed by `data`, clearing `readDataPrefix`.
    ///
    /// - returns: `nil` only when both the prefix and `data` are `nil`.
    private func consumeReadData(_ data: Data?) -> Data? {
        defer {
            readDataPrefix = nil
        }

        guard let prefix = readDataPrefix else {
            return data
        }
        guard let data = data else {
            return prefix
        }

        var wholeData = prefix
        wholeData.append(data)
        return wholeData
    }

    /// Logs a read failure (unless `isExpectedReadError` says it is routine) and schedules `disconnect()`.
    ///
    /// Shared by `readData()` and `readDataTo(length:)` so both treat read errors the same way.
    private func handleReadError(_ error: Error?) {
        if !NWTCPSocket.isExpectedReadError(error) {
            DDLogError("NWTCPSocket got an error when reading data: \(String(describing: error))")
        }
        queueCall {
            self.disconnect()
        }
    }

    /// Whether `error` is an "operation canceled" (`ECANCELED`) or "socket is not connected" (`ENOTCONN`)
    /// POSIX error; these are not logged.
    private static func isExpectedReadError(_ error: Error?) -> Bool {
        guard let e = error as NSError? else {
            return false
        }
        return (e.domain == "kNWErrorDomainPOSIX" && e.code == POSIXError.ECANCELED.rawValue) // Operation canceled
            || (e.domain == NSPOSIXErrorDomain && e.code == POSIXError.ENOTCONN.rawValue) // Socket is not connected
    }

    /// Returns the current connection or, when there is none (e.g. `connectTo` bailed out because
    /// `RawSocketFactory.TunnelProvider` was `nil`), schedules `disconnect()` and returns `nil`.
    private func connectionOrDisconnect() -> NWTCPConnection? {
        if let connection = connection {
            return connection
        }
        queueCall {
            self.disconnect()
        }
        return nil
    }

    /// Whether a connection exists and has not reached the `.cancelled` state.
    private var hasLiveConnection: Bool {
        guard let connection = connection else {
            return false
        }
        return connection.state != .cancelled
    }

    /// Cancels the underlying connection, if any.
    private func cancel() {
        connection?.cancel()
    }

    /// Completes a graceful `disconnect()` once no write is pending.
    private func checkStatus() {
        if closeAfterWriting && !writePending {
            cancel()
        }
    }

    deinit {
        guard let connection = connection else {
            return
        }

        connection.removeObserver(self, forKeyPath: "state")
    }
}
|