import Foundation import NetworkExtension /// The TCP socket build upon `NWTCPConnection`. /// /// - warning: This class is not thread-safe. public class NWTCPSocket: NSObject, RawTCPSocketProtocol { private var connection: NWTCPConnection? private var writePending = false private var closeAfterWriting = false private var cancelled = false private var scanner: StreamScanner! private var scanning: Bool = false private var readDataPrefix: Data? // MARK: RawTCPSocketProtocol implementation /// The `RawTCPSocketDelegate` instance. weak open var delegate: RawTCPSocketDelegate? /// If the socket is connected. public var isConnected: Bool { return connection != nil && connection!.state == .connected } /// The source address. /// /// - note: Always returns `nil`. public var sourceIPAddress: IPAddress? { return nil } /// The source port. /// /// - note: Always returns `nil`. public var sourcePort: Port? { return nil } /// The destination address. /// /// - note: Always returns `nil`. public var destinationIPAddress: IPAddress? { return nil } /// The destination port. /// /// - note: Always returns `nil`. public var destinationPort: Port? { return nil } /** Connect to remote host. - parameter host: Remote host. - parameter port: Remote port. - parameter enableTLS: Should TLS be enabled. - parameter tlsSettings: The settings of TLS. - throws: Never throws. */ public func connectTo(host: String, port: Int, enableTLS: Bool, tlsSettings: [AnyHashable: Any]?) throws { let endpoint = NWHostEndpoint(hostname: host, port: "\(port)") let tlsParameters = NWTLSParameters() if let tlsSettings = tlsSettings as? [String: AnyObject] { tlsParameters.setValuesForKeys(tlsSettings) } guard let connection = RawSocketFactory.TunnelProvider?.createTCPConnection(to: endpoint, enableTLS: enableTLS, tlsParameters: tlsParameters, delegate: nil) else { // This should only happen when the extension is already stopped and `RawSocketFactory.TunnelProvider` is set to `nil`. return } self.connection = connection connection.addObserver(self, forKeyPath: "state", options: [.initial, .new], context: nil) } /** Disconnect the socket. The socket will disconnect elegantly after any queued writing data are successfully sent. */ public func disconnect() { cancelled = true if connection == nil || connection!.state == .cancelled { delegate?.didDisconnectWith(socket: self) } else { closeAfterWriting = true checkStatus() } } /** Disconnect the socket immediately. */ public func forceDisconnect() { cancelled = true if connection == nil || connection!.state == .cancelled { delegate?.didDisconnectWith(socket: self) } else { cancel() } } /** Send data to remote. - parameter data: Data to send. - warning: This should only be called after the last write is finished, i.e., `delegate?.didWriteData()` is called. */ public func write(data: Data) { guard !cancelled else { return } guard data.count > 0 else { QueueFactory.getQueue().async { self.delegate?.didWrite(data: data, by: self) } return } send(data: data) } /** Read data from the socket. - warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called. */ public func readData() { guard !cancelled else { return } connection!.readMinimumLength(1, maximumLength: Opt.MAXNWTCPSocketReadDataSize) { data, error in guard error == nil else { let e = error! as NSError let ignore = ( e.domain == "kNWErrorDomainPOSIX" && e.code == POSIXError.ECANCELED.rawValue // Operation canceled || e.domain == NSPOSIXErrorDomain && e.code == POSIXError.ENOTCONN.rawValue // Socket is not connected ) if !ignore { DDLogError("NWTCPSocket got an error when reading data: \(String(describing: error))") } self.queueCall { self.disconnect() } return } self.readCallback(data: data) } } /** Read specific length of data from the socket. - parameter length: The length of the data to read. - warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called. */ public func readDataTo(length: Int) { guard !cancelled else { return } connection!.readLength(length) { data, error in guard error == nil else { DDLogError("NWTCPSocket got an error when reading data: \(String(describing: error))") self.queueCall { self.disconnect() } return } self.readCallback(data: data) } } /** Read data until a specific pattern (including the pattern). - parameter data: The pattern. - warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called. */ public func readDataTo(data: Data) { readDataTo(data: data, maxLength: 0) } // Actually, this method is available as `- (void)readToPattern:(id)arg1 maximumLength:(unsigned int)arg2 completionHandler:(id /* block */)arg3;` // which is sadly not available in public header for some reason I don't know. // I don't want to do it myself since This method is not trival to implement and I don't like reinventing the wheel. // Here is only the most naive version, which may not be the optimal if using with large data blocks. /** Read data until a specific pattern (including the pattern). - parameter data: The pattern. - parameter maxLength: The max length of data to scan for the pattern. - warning: This should only be called after the last read is finished, i.e., `delegate?.didReadData()` is called. */ public func readDataTo(data: Data, maxLength: Int) { guard !cancelled else { return } var maxLength = maxLength if maxLength == 0 { maxLength = Opt.MAXNWTCPScanLength } scanner = StreamScanner(pattern: data, maximumLength: maxLength) scanning = true readData() } private func queueCall(_ block: @escaping () -> Void) { QueueFactory.getQueue().async(execute: block) } override public func observeValue(forKeyPath keyPath: String?, of object: Any?, change: [NSKeyValueChangeKey : Any]?, context: UnsafeMutableRawPointer?) { guard keyPath == "state" else { return } switch connection!.state { case .connected: queueCall { self.delegate?.didConnectWith(socket: self) } case .disconnected: cancelled = true cancel() case .cancelled: cancelled = true queueCall { let delegate = self.delegate self.delegate = nil delegate?.didDisconnectWith(socket: self) } default: break } } private func readCallback(data: Data?) { guard !cancelled else { return } queueCall { guard let data = self.consumeReadData(data) else { // remote read is closed, but this is okay, nothing need to be done, if this socket is read again, then error occurs. return } if self.scanning { guard let (match, rest) = self.scanner.addAndScan(data) else { self.readData() return } self.scanner = nil self.scanning = false guard let matchData = match else { // do not find match in the given length, stop now return } self.readDataPrefix = rest self.delegate?.didRead(data: matchData, from: self) } else { self.delegate?.didRead(data: data, from: self) } } } private func send(data: Data) { writePending = true self.connection!.write(data) { error in self.queueCall { self.writePending = false guard error == nil else { DDLogError("NWTCPSocket got an error when writing data: \(String(describing: error))") self.disconnect() return } self.delegate?.didWrite(data: data, by: self) self.checkStatus() } } } private func consumeReadData(_ data: Data?) -> Data? { defer { readDataPrefix = nil } if readDataPrefix == nil { return data } if data == nil { return readDataPrefix } var wholeData = readDataPrefix! wholeData.append(data!) return wholeData } private func cancel() { connection?.cancel() } private func checkStatus() { if closeAfterWriting && !writePending { cancel() } } deinit { guard let connection = connection else { return } connection.removeObserver(self, forKeyPath: "state") } }