diff --git a/GlassVPN/PacketTunnelProvider.swift b/GlassVPN/PacketTunnelProvider.swift index b584ee8..e8288c3 100644 --- a/GlassVPN/PacketTunnelProvider.swift +++ b/GlassVPN/PacketTunnelProvider.swift @@ -1,7 +1,5 @@ import NetworkExtension -fileprivate var db: SQLiteDatabase! -fileprivate var pStmt: OpaquePointer! fileprivate var filterDomains: [String]! fileprivate var filterOptions: [(block: Bool, ignore: Bool)]! @@ -9,7 +7,7 @@ fileprivate var filterOptions: [(block: Bool, ignore: Bool)]! // MARK: Backward DNS Binary Tree Lookup fileprivate func reloadDomainFilter() { - let tmp = db.loadFilters()?.map({ + let tmp = AppDB?.loadFilters()?.map({ (String($0.reversed()), $1) }).sorted(by: { $0.0 < $1.0 }) ?? [] filterDomains = tmp.map { $0.0 } @@ -35,6 +33,18 @@ fileprivate func filterIndex(for domain: String) -> Int { return -1 } +private let queue = DispatchQueue.init(label: "PSIGlassDNSQueue", qos: .userInteractive, target: .main) + +private func logAsync(_ domain: String, blocked: Bool) { + queue.async { + do { + try AppDB?.logWrite(domain, blocked: blocked) + } catch { + DDLogWarn("Couldn't write: \(error)") + } + } +} + // MARK: ObserverFactory @@ -52,11 +62,11 @@ class LDObserverFactory: ObserverFactory { let i = filterIndex(for: session.host) if i >= 0 { let (block, ignore) = filterOptions[i] - if !ignore { try? db.logWrite(pStmt, session.host, blocked: block) } + if !ignore { logAsync(session.host, blocked: block) } if block { socket.forceDisconnect() } } else { // TODO: disable filter during recordings - try? db.logWrite(pStmt, session.host) + logAsync(session.host, blocked: false) } default: break @@ -76,9 +86,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { override func startTunnel(options: [String : NSObject]?, completionHandler: @escaping (Error?) -> Void) { do { - db = try SQLiteDatabase.open() - db.initCommonScheme() - pStmt = try db.logWritePrepare() + try SQLiteDatabase.open().initCommonScheme() } catch { completionHandler(error) return @@ -135,9 +143,6 @@ class PacketTunnelProvider: NEPacketTunnelProvider { ObserverFactory.currentFactory = nil proxyServer.stop() proxyServer = nil - db.prepared(finalize: pStmt) - pStmt = nil - db = nil filterDomains = nil filterOptions = nil completionHandler() diff --git a/main/DB/DBCommon.swift b/main/DB/DBCommon.swift index 38e939a..0bb9244 100644 --- a/main/DB/DBCommon.swift +++ b/main/DB/DBCommon.swift @@ -25,16 +25,22 @@ extension CreateTable { } extension SQLiteDatabase { +// /// `INSERT INTO cache (dns, opt) VALUES (?, ?);` +// func logWritePrepare() throws -> OpaquePointer { +// try prepare(sql: "INSERT INTO cache (dns, opt) VALUES (?, ?);") +// } +// /// `prep` must exist and be initialized with `logWritePrepare()` +// func logWrite(_ pStmt: OpaquePointer!, _ domain: String, blocked: Bool = false) throws { +// guard let prep = pStmt else { +// return +// } +// try prepared(run: prep, bind: [BindText(domain), BindInt32(blocked ? 1 : 0)]) +// } /// `INSERT INTO cache (dns, opt) VALUES (?, ?);` - func logWritePrepare() throws -> OpaquePointer { - try prepare(sql: "INSERT INTO cache (dns, opt) VALUES (?, ?);") - } - /// `prep` must exist and be initialized with `logWritePrepare()` - func logWrite(_ pStmt: OpaquePointer!, _ domain: String, blocked: Bool = false) throws { - guard let prep = pStmt else { - return - } - try prepared(run: prep, bind: [BindText(domain), BindInt32(blocked ? 1 : 0)]) + func logWrite(_ domain: String, blocked: Bool = false) throws { + try self.run(sql: "INSERT INTO cache (dns, opt) VALUES (?, ?);", + bind: [BindText(domain), BindInt32(blocked ? 1 : 0)]) + { try ifStep($0, SQLITE_DONE) } } } @@ -52,10 +58,10 @@ extension CreateTable { } struct FilterOptions: OptionSet { - let rawValue: Int32 + let rawValue: Int32 static let none = FilterOptions([]) - static let blocked = FilterOptions(rawValue: 1 << 0) - static let ignored = FilterOptions(rawValue: 1 << 1) + static let blocked = FilterOptions(rawValue: 1 << 0) + static let ignored = FilterOptions(rawValue: 1 << 1) static let any = FilterOptions(rawValue: 0b11) } diff --git a/main/DB/DBCore.swift b/main/DB/DBCore.swift index a4ad215..8332033 100644 --- a/main/DB/DBCore.swift +++ b/main/DB/DBCore.swift @@ -35,7 +35,7 @@ class SQLiteDatabase { } deinit { - sqlite3_close(dbPointer) + sqlite3_close_v2(dbPointer) } static func destroyDatabase(path: String = URL.internalDB().relativePath) { @@ -47,15 +47,10 @@ class SQLiteDatabase { static func open(path: String = URL.internalDB().relativePath) throws -> SQLiteDatabase { var db: OpaquePointer? - //sqlite3_open_v2(path, &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_SHAREDCACHE, nil) - if sqlite3_open(path, &db) == SQLITE_OK { + if sqlite3_open_v2(path, &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, nil) == SQLITE_OK { return SQLiteDatabase(dbPointer: db) } else { - defer { - if db != nil { - sqlite3_close(db) - } - } + defer { sqlite3_close_v2(db) } if let errorPointer = sqlite3_errmsg(db) { let message = String(cString: errorPointer) throw SQLiteError.OpenDatabase(message: message) @@ -222,6 +217,7 @@ extension SQLiteDatabase { func prepare(sql: String) throws -> OpaquePointer { var pStmt: OpaquePointer? guard sqlite3_prepare_v2(dbPointer, sql, -1, &pStmt, nil) == SQLITE_OK, let S = pStmt else { + sqlite3_finalize(pStmt) throw SQLiteError.Prepare(message: errorMessage) } return S diff --git a/main/Data Source/TestDataSource.swift b/main/Data Source/TestDataSource.swift index bcfd84f..cf2f6b4 100644 --- a/main/Data Source/TestDataSource.swift +++ b/main/Data Source/TestDataSource.swift @@ -2,25 +2,22 @@ import Foundation #if IOS_SIMULATOR -private let db = AppDB! -private var pStmt: OpaquePointer? - class TestDataSource { static func load() { QLog.Debug("SQLite path: \(URL.internalDB())") + let db = AppDB! let deleted = db.dnsLogsDelete("test.com", strict: false) try? db.run(sql: "DELETE FROM cache;") QLog.Debug("Deleting \(deleted) rows matching 'test.com' (+ \(db.numberOfChanges) in cache)") QLog.Debug("Writing 33 test logs") - pStmt = try! db.logWritePrepare() - try? db.logWrite(pStmt, "keeptest.com", blocked: false) - for _ in 1...4 { try? db.logWrite(pStmt, "test.com", blocked: false) } - for _ in 1...7 { try? db.logWrite(pStmt, "i.test.com", blocked: false) } - for i in 1...8 { try? db.logWrite(pStmt, "b.test.com", blocked: i>5) } - for i in 1...13 { try? db.logWrite(pStmt, "bi.test.com", blocked: i%2==0) } + try? db.logWrite("keeptest.com", blocked: false) + for _ in 1...4 { try? db.logWrite("test.com", blocked: false) } + for _ in 1...7 { try? db.logWrite("i.test.com", blocked: false) } + for i in 1...8 { try? db.logWrite("b.test.com", blocked: i>5) } + for i in 1...13 { try? db.logWrite("bi.test.com", blocked: i%2==0) } db.dnsLogsPersist() @@ -36,7 +33,7 @@ class TestDataSource { @objc static func insertRandom() { //QLog.Debug("Inserting 1 periodic log entry") - try? db.logWrite(pStmt, "\(arc4random() % 5).count.test.com", blocked: true) + try? AppDB?.logWrite("\(arc4random() % 5).count.test.com", blocked: true) } } #endif