import Foundation import SQLite3 typealias Timestamp = Int64 struct GroupedDomain { let domain: String, total: Int32, blocked: Int32, lastModified: Timestamp var options: FilterOptions? = nil } struct FilterOptions: OptionSet { let rawValue: Int32 static let none = FilterOptions(rawValue: 0) static let blocked = FilterOptions(rawValue: 1 << 0) static let ignored = FilterOptions(rawValue: 1 << 1) static let any = FilterOptions(rawValue: 0b11) } enum SQLiteError: Error { case OpenDatabase(message: String) case Prepare(message: String) case Step(message: String) case Bind(message: String) } // MARK: - SQLiteDatabase class SQLiteDatabase { private let dbPointer: OpaquePointer? private init(dbPointer: OpaquePointer?) { // print("SQLite path: \(URL.internalDB())") self.dbPointer = dbPointer } fileprivate var errorMessage: String { if let errorPointer = sqlite3_errmsg(dbPointer) { let errorMessage = String(cString: errorPointer) return errorMessage } else { return "No error message provided from sqlite." } } deinit { sqlite3_close(dbPointer) } static func destroyDatabase(path: String = URL.internalDB().relativePath) { if FileManager.default.fileExists(atPath: path) { do { try FileManager.default.removeItem(atPath: path) } catch { print("Could not destroy database file: \(path)") } } } // static func export() throws -> URL { // let fmt = DateFormatter() // fmt.dateFormat = "yyyy-MM-dd" // let dest = FileManager.default.exportDir().appendingPathComponent("\(fmt.string(from: Date()))-dns-log.sqlite") // try? FileManager.default.removeItem(at: dest) // try FileManager.default.copyItem(at: FileManager.default.internalDB(), to: dest) // return dest // } static func open(path: String = URL.internalDB().relativePath) throws -> SQLiteDatabase { var db: OpaquePointer? //sqlite3_open_v2(path, &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_SHAREDCACHE, nil) if sqlite3_open(path, &db) == SQLITE_OK { return SQLiteDatabase(dbPointer: db) } else { defer { if db != nil { sqlite3_close(db) } } if let errorPointer = sqlite3_errmsg(db) { let message = String(cString: errorPointer) throw SQLiteError.OpenDatabase(message: message) } else { throw SQLiteError.OpenDatabase(message: "No error message provided from sqlite.") } } } func run(sql: String, bind: ((OpaquePointer) -> Bool)?, step: (OpaquePointer) throws -> T) throws -> T { var statement: OpaquePointer? guard sqlite3_prepare_v2(dbPointer, sql, -1, &statement, nil) == SQLITE_OK, let stmt = statement else { throw SQLiteError.Prepare(message: errorMessage) } defer { sqlite3_finalize(stmt) } guard bind?(stmt) ?? true else { throw SQLiteError.Bind(message: errorMessage) } return try step(stmt) } func ifStep(_ stmt: OpaquePointer, _ expected: Int32) throws { guard sqlite3_step(stmt) == expected else { throw SQLiteError.Step(message: errorMessage) } } func createTable(table: SQLTable.Type) throws { try run(sql: table.createStatement, bind: nil) { try ifStep($0, SQLITE_DONE) } } func vacuum() { try? run(sql: "VACUUM;", bind: nil) { try ifStep($0, SQLITE_DONE) } } } protocol SQLTable { static var createStatement: String { get } } // MARK: - Easy Access func private extension SQLiteDatabase { func bindInt(_ stmt: OpaquePointer, _ col: Int32, _ value: Int32) -> Bool { sqlite3_bind_int(stmt, col, value) == SQLITE_OK } func bindInt64(_ stmt: OpaquePointer, _ col: Int32, _ value: sqlite3_int64) -> Bool { sqlite3_bind_int64(stmt, col, value) == SQLITE_OK } func bindText(_ stmt: OpaquePointer, _ col: Int32, _ value: String) -> Bool { sqlite3_bind_text(stmt, col, (value as NSString).utf8String, -1, nil) == SQLITE_OK } func bindTextOrNil(_ stmt: OpaquePointer, _ col: Int32, _ value: String?) -> Bool { sqlite3_bind_text(stmt, col, (value == nil) ? nil : (value! as NSString).utf8String, -1, nil) == SQLITE_OK } func readText(_ stmt: OpaquePointer, _ col: Int32) -> String? { let val = sqlite3_column_text(stmt, col) return (val != nil ? String(cString: val!) : nil) } func allRows(_ stmt: OpaquePointer, _ fn: (OpaquePointer) -> T) -> [T] { var r: [T] = [] while (sqlite3_step(stmt) == SQLITE_ROW) { r.append(fn(stmt)) } return r } func allRowsKeyed(_ stmt: OpaquePointer, _ fn: (OpaquePointer) -> (key: T, value: U)) -> [T:U] { var r: [T:U] = [:] while (sqlite3_step(stmt) == SQLITE_ROW) { let (k,v) = fn(stmt); r[k] = v } return r } } extension SQLiteDatabase { func initScheme() { try? self.createTable(table: DNSQueryT.self) try? self.createTable(table: DNSFilterT.self) try? self.createTable(table: Recording.self) } } // MARK: - DNSQueryT private struct DNSQueryT: SQLTable { let ts: Timestamp let domain: String let wasBlocked: Bool let options: FilterOptions static var createStatement: String { return """ CREATE TABLE IF NOT EXISTS req( ts BIGINT DEFAULT (strftime('%s','now')), domain VARCHAR(255) NOT NULL, logOpt INT DEFAULT 0 ); """ } } extension SQLiteDatabase { // MARK: insert func insertDNSQuery(_ domain: String, blocked: Bool) throws { try? run(sql: "INSERT INTO req (domain, logOpt) VALUES (?, ?);", bind: { self.bindText($0, 1, domain) && self.bindInt($0, 2, blocked ? 1 : 0) }) { try ifStep($0, SQLITE_DONE) } } // MARK: delete func destroyContent() throws { try? run(sql: "DROP TABLE IF EXISTS req;", bind: nil) { try ifStep($0, SQLITE_DONE) } try? createTable(table: DNSQueryT.self) } /// Delete rows matching `ts >= ? AND "domain" OR "*.domain"` @discardableResult func deleteRows(matching domain: String, since ts: Timestamp = 0) throws -> Int32 { try run(sql: "DELETE FROM req WHERE ts >= ? AND (domain = ? OR domain LIKE '%.' || ?);", bind: { self.bindInt64($0, 1, ts) && self.bindText($0, 2, domain) && self.bindText($0, 3, domain) }) { stmt -> Int32 in try ifStep(stmt, SQLITE_DONE) return sqlite3_changes(dbPointer) } } // MARK: read func readGroupedDomain(_ stmt: OpaquePointer) -> GroupedDomain { GroupedDomain(domain: readText(stmt, 0) ?? "", total: sqlite3_column_int(stmt, 1), blocked: sqlite3_column_int(stmt, 2), lastModified: sqlite3_column_int64(stmt, 3)) } func domainList(since ts: Timestamp = 0) -> [GroupedDomain]? { try? run(sql: "SELECT domain, COUNT(*), SUM(logOpt&1), MAX(ts) FROM req \(ts == 0 ? "" : "WHERE ts > ?") GROUP BY domain ORDER BY 4 DESC;", bind: { ts == 0 || self.bindInt64($0, 1, ts) }) { allRows($0) { readGroupedDomain($0) } } } func domainList(matching domain: String) -> [GroupedDomain]? { try? run(sql: "SELECT domain, COUNT(*), SUM(logOpt&1), MAX(ts) FROM req WHERE (domain = ? OR domain LIKE '%.' || ?) GROUP BY domain ORDER BY 4 DESC;", bind: { self.bindText($0, 1, domain) && self.bindText($0, 2, domain) }) { allRows($0) { readGroupedDomain($0) } } } func timesForDomain(_ fullDomain: String) -> [(Timestamp, Bool)]? { try? run(sql: "SELECT ts, logOpt FROM req WHERE domain = ?;", bind: { self.bindText($0, 1, fullDomain) }) { allRows($0) { (sqlite3_column_int64($0, 0), sqlite3_column_int($0, 1) > 0) } } } } // MARK: - DNSFilterT private struct DNSFilterT: SQLTable { let domain: String let options: FilterOptions static var createStatement: String { return """ CREATE TABLE IF NOT EXISTS filter( domain VARCHAR(255) UNIQUE NOT NULL, opt INT DEFAULT 0 ); """ } } extension SQLiteDatabase { // MARK: read func loadFilters() -> [String : FilterOptions]? { try? run(sql: "SELECT domain, opt FROM filter ORDER BY domain ASC;", bind: nil) { allRowsKeyed($0) { (key: readText($0, 0) ?? "", value: FilterOptions(rawValue: sqlite3_column_int($0, 1))) } } } // MARK: write func setFilter(_ domain: String, _ value: FilterOptions?) { func removeFilter() { try? run(sql: "DELETE FROM filter WHERE domain = ? LIMIT 1;", bind: { self.bindText($0, 1, domain) }) { stmt -> Void in sqlite3_step(stmt) } } guard let rv = value?.rawValue, rv > 0 else { removeFilter() return } func createFilter() throws { try run(sql: "INSERT OR FAIL INTO filter (domain, opt) VALUES (?, ?);", bind: { self.bindText($0, 1, domain) && self.bindInt($0, 2, rv) }) { try ifStep($0, SQLITE_DONE) } } func updateFilter() { try? run(sql: "UPDATE filter SET opt = ? WHERE domain = ? LIMIT 1;", bind: { self.bindInt($0, 1, rv) && self.bindText($0, 2, domain) }) { stmt -> Void in sqlite3_step(stmt) } } do { try createFilter() } catch { updateFilter() } } } // MARK: - Recordings struct Recording: SQLTable { let start: Timestamp let stop: Timestamp? var appId: String? = nil var title: String? = nil var notes: String? = nil static var createStatement: String { return """ CREATE TABLE IF NOT EXISTS rec( start BIGINT DEFAULT (strftime('%s','now')), stop BIGINT, appid VARCHAR(255), title VARCHAR(255), notes TEXT ); """ } } extension SQLiteDatabase { // MARK: write func startNewRecording(_ title: String? = nil, appBundle: String? = nil) throws -> Recording { try run(sql: "INSERT INTO rec (title, appid) VALUES (?, ?);", bind: { self.bindTextOrNil($0, 1, title) && self.bindTextOrNil($0, 2, appBundle) }) { stmt -> Recording in try ifStep(stmt, SQLITE_DONE) return ongoingRecording()! } } func stopRecordings() { try? run(sql: "UPDATE rec SET stop = (strftime('%s','now')) WHERE stop IS NULL;", bind: nil) { stmt -> Void in sqlite3_step(stmt) } } func updateRecording(_ r: Recording) { try? run(sql: "UPDATE rec SET title = ?, appid = ?, notes = ? WHERE start = ? LIMIT 1;", bind: { self.bindTextOrNil($0, 1, r.title) && self.bindTextOrNil($0, 2, r.appId) && self.bindTextOrNil($0, 3, r.notes) && self.bindInt64($0, 4, r.start) }) { stmt -> Void in sqlite3_step(stmt) } } func deleteRecording(_ r: Recording) throws -> Bool { try run(sql: "DELETE FROM rec WHERE start = ? LIMIT 1;", bind: { self.bindInt64($0, 1, r.start) }) { try ifStep($0, SQLITE_DONE) return sqlite3_changes(dbPointer) > 0 } } // MARK: read func readRecording(_ stmt: OpaquePointer) -> Recording { let end = sqlite3_column_int64(stmt, 1) return Recording(start: sqlite3_column_int64(stmt, 0), stop: end == 0 ? nil : end, appId: readText(stmt, 2), title: readText(stmt, 3), notes: readText(stmt, 4)) } func ongoingRecording() -> Recording? { try? run(sql: "SELECT * FROM rec WHERE stop IS NULL LIMIT 1;", bind: nil) { try ifStep($0, SQLITE_ROW) return readRecording($0) } } func allRecordings() -> [Recording]? { try? run(sql: "SELECT * FROM rec WHERE stop IS NOT NULL;", bind: nil) { allRows($0) { readRecording($0) } } } }