Fixed a number of bugs, added in logic to handle id blinding being enabled and migrated session SOGS IPs to domains

Added logic to handle the case when ID blinded gets switched on server-side and the app already has open groups with cached capabilities
Added logic to migrate users from using HTTP and IP-based session open groups to use the HTTPS domain name url instead
Fixed a bug with the PushNotificationAPI update registration response structure
Fixed some broken unit tests (and a bug which was introduced in an earlier optimisation)
Fixed a bug where trusting a contact (to download their messages) wouldn't trigger the message UI to update
Fixed a bug where tapping a push notification wasn't opening the associated thread when the app isn't running in the background
pull/612/head
Morgan Pretty 3 years ago
parent ff2d96e0d5
commit 76f7e4e246

@ -742,7 +742,7 @@
FDC290A927D9B46D005DAE71 /* NimbleExtensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC290A727D9B46D005DAE71 /* NimbleExtensions.swift */; }; FDC290A927D9B46D005DAE71 /* NimbleExtensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC290A727D9B46D005DAE71 /* NimbleExtensions.swift */; };
FDC290AA27D9B6FD005DAE71 /* Mock.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC290A527D860CE005DAE71 /* Mock.swift */; }; FDC290AA27D9B6FD005DAE71 /* Mock.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC290A527D860CE005DAE71 /* Mock.swift */; };
FDC290B327DFF9F5005DAE71 /* TestOnionRequestAPI.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC290B227DFF9F5005DAE71 /* TestOnionRequestAPI.swift */; }; FDC290B327DFF9F5005DAE71 /* TestOnionRequestAPI.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC290B227DFF9F5005DAE71 /* TestOnionRequestAPI.swift */; };
FDC4380927B31D4E00C60D73 /* SOGSError.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC4380827B31D4E00C60D73 /* SOGSError.swift */; }; FDC4380927B31D4E00C60D73 /* OpenGroupAPIError.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC4380827B31D4E00C60D73 /* OpenGroupAPIError.swift */; };
FDC4381527B329CE00C60D73 /* NonceGenerator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC4381427B329CE00C60D73 /* NonceGenerator.swift */; }; FDC4381527B329CE00C60D73 /* NonceGenerator.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC4381427B329CE00C60D73 /* NonceGenerator.swift */; };
FDC4381727B32EC700C60D73 /* Personalization.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC4381627B32EC700C60D73 /* Personalization.swift */; }; FDC4381727B32EC700C60D73 /* Personalization.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC4381627B32EC700C60D73 /* Personalization.swift */; };
FDC4382027B36ADC00C60D73 /* SOGSEndpoint.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC4381F27B36ADC00C60D73 /* SOGSEndpoint.swift */; }; FDC4382027B36ADC00C60D73 /* SOGSEndpoint.swift in Sources */ = {isa = PBXBuildFile; fileRef = FDC4381F27B36ADC00C60D73 /* SOGSEndpoint.swift */; };
@ -1778,7 +1778,7 @@
FDC290A527D860CE005DAE71 /* Mock.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Mock.swift; sourceTree = "<group>"; }; FDC290A527D860CE005DAE71 /* Mock.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Mock.swift; sourceTree = "<group>"; };
FDC290A727D9B46D005DAE71 /* NimbleExtensions.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NimbleExtensions.swift; sourceTree = "<group>"; }; FDC290A727D9B46D005DAE71 /* NimbleExtensions.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NimbleExtensions.swift; sourceTree = "<group>"; };
FDC290B227DFF9F5005DAE71 /* TestOnionRequestAPI.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TestOnionRequestAPI.swift; sourceTree = "<group>"; }; FDC290B227DFF9F5005DAE71 /* TestOnionRequestAPI.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TestOnionRequestAPI.swift; sourceTree = "<group>"; };
FDC4380827B31D4E00C60D73 /* SOGSError.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SOGSError.swift; sourceTree = "<group>"; }; FDC4380827B31D4E00C60D73 /* OpenGroupAPIError.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpenGroupAPIError.swift; sourceTree = "<group>"; };
FDC4381427B329CE00C60D73 /* NonceGenerator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NonceGenerator.swift; sourceTree = "<group>"; }; FDC4381427B329CE00C60D73 /* NonceGenerator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NonceGenerator.swift; sourceTree = "<group>"; };
FDC4381627B32EC700C60D73 /* Personalization.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Personalization.swift; sourceTree = "<group>"; }; FDC4381627B32EC700C60D73 /* Personalization.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Personalization.swift; sourceTree = "<group>"; };
FDC4381F27B36ADC00C60D73 /* SOGSEndpoint.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SOGSEndpoint.swift; sourceTree = "<group>"; }; FDC4381F27B36ADC00C60D73 /* SOGSEndpoint.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SOGSEndpoint.swift; sourceTree = "<group>"; };
@ -3756,7 +3756,7 @@
isa = PBXGroup; isa = PBXGroup;
children = ( children = (
FDC4381F27B36ADC00C60D73 /* SOGSEndpoint.swift */, FDC4381F27B36ADC00C60D73 /* SOGSEndpoint.swift */,
FDC4380827B31D4E00C60D73 /* SOGSError.swift */, FDC4380827B31D4E00C60D73 /* OpenGroupAPIError.swift */,
FDC4381627B32EC700C60D73 /* Personalization.swift */, FDC4381627B32EC700C60D73 /* Personalization.swift */,
FDC4381427B329CE00C60D73 /* NonceGenerator.swift */, FDC4381427B329CE00C60D73 /* NonceGenerator.swift */,
FDC438C227BB512200C60D73 /* SodiumProtocols.swift */, FDC438C227BB512200C60D73 /* SodiumProtocols.swift */,
@ -5191,7 +5191,7 @@
FD716E682850318E00C96BF4 /* CallMode.swift in Sources */, FD716E682850318E00C96BF4 /* CallMode.swift in Sources */,
FD09799527FE7B8E00936362 /* Interaction.swift in Sources */, FD09799527FE7B8E00936362 /* Interaction.swift in Sources */,
FD5C72FF284F0F120029977D /* MessageReceiver+ConfigurationMessages.swift in Sources */, FD5C72FF284F0F120029977D /* MessageReceiver+ConfigurationMessages.swift in Sources */,
FDC4380927B31D4E00C60D73 /* SOGSError.swift in Sources */, FDC4380927B31D4E00C60D73 /* OpenGroupAPIError.swift in Sources */,
FDC4382027B36ADC00C60D73 /* SOGSEndpoint.swift in Sources */, FDC4382027B36ADC00C60D73 /* SOGSEndpoint.swift in Sources */,
FDC438C927BB706500C60D73 /* SendDirectMessageRequest.swift in Sources */, FDC438C927BB706500C60D73 /* SendDirectMessageRequest.swift in Sources */,
C3A71D1F25589AC30043A11F /* WebSocketResources.pb.swift in Sources */, C3A71D1F25589AC30043A11F /* WebSocketResources.pb.swift in Sources */,

@ -169,6 +169,16 @@ public class ConversationViewModel: OWSAudioPlayerDelegate {
columns: Interaction.Columns columns: Interaction.Columns
.allCases .allCases
.filter { $0 != .wasRead } .filter { $0 != .wasRead }
),
PagedData.ObservedChanges(
table: Contact.self,
columns: [.isTrusted],
joinToPagedType: {
let interaction: TypedTableAlias<Interaction> = TypedTableAlias()
let contact: TypedTableAlias<Contact> = TypedTableAlias()
return SQL("LEFT JOIN \(Contact.self) ON \(contact[.id]) = \(interaction[.threadId])")
}()
) )
], ],
filterSQL: MessageViewModel.filterSQL(threadId: threadId), filterSQL: MessageViewModel.filterSQL(threadId: threadId),
@ -189,7 +199,6 @@ public class ConversationViewModel: OWSAudioPlayerDelegate {
], ],
dataQuery: MessageViewModel.AttachmentInteractionInfo.baseQuery, dataQuery: MessageViewModel.AttachmentInteractionInfo.baseQuery,
joinToPagedType: MessageViewModel.AttachmentInteractionInfo.joinToViewModelQuerySQL, joinToPagedType: MessageViewModel.AttachmentInteractionInfo.joinToViewModelQuerySQL,
groupPagedType: MessageViewModel.AttachmentInteractionInfo.groupViewModelQuerySQL,
associateData: MessageViewModel.AttachmentInteractionInfo.createAssociateDataClosure() associateData: MessageViewModel.AttachmentInteractionInfo.createAssociateDataClosure()
), ),
AssociatedRecord<MessageViewModel.TypingIndicatorInfo, MessageViewModel>( AssociatedRecord<MessageViewModel.TypingIndicatorInfo, MessageViewModel>(

@ -659,7 +659,7 @@ final class HomeVC: BaseVC, UITableViewDataSource, UITableViewDelegate, NewConve
} }
let conversationVC: ConversationVC = ConversationVC(threadId: threadId, threadVariant: variant, focusedInteractionId: focusedInteractionId) let conversationVC: ConversationVC = ConversationVC(threadId: threadId, threadVariant: variant, focusedInteractionId: focusedInteractionId)
self.navigationController?.setViewControllers([ self, conversationVC ], animated: true) self.navigationController?.setViewControllers([ self, conversationVC ], animated: animated)
} }
@objc private func openSettings() { @objc private func openSettings() {

@ -191,6 +191,13 @@ class AppDelegate: UIResponder, UIApplicationDelegate, UNUserNotificationCenterD
// Trigger any launch-specific jobs and start the JobRunner // Trigger any launch-specific jobs and start the JobRunner
JobRunner.appDidFinishLaunching() JobRunner.appDidFinishLaunching()
/// Setup the UI
///
/// **Note:** This **MUST** be run before calling `AppReadiness.setAppIsReady()` otherwise if
/// we are launching the app from a push notification the HomeVC won't be setup yet and it won't open the
/// related thread
self.ensureRootViewController(isPreAppReadyCall: true)
// Note that this does much more than set a flag; // Note that this does much more than set a flag;
// it will also run all deferred blocks (including the JobRunner // it will also run all deferred blocks (including the JobRunner
// 'appDidBecomeActive' method) // 'appDidBecomeActive' method)
@ -220,9 +227,6 @@ class AppDelegate: UIResponder, UIApplicationDelegate, UNUserNotificationCenterD
} }
} }
} }
// Setup the UI
self.ensureRootViewController()
} }
private func showFailedMigrationAlert() { private func showFailedMigrationAlert() {
@ -321,8 +325,8 @@ class AppDelegate: UIResponder, UIApplicationDelegate, UNUserNotificationCenterD
} }
} }
private func ensureRootViewController() { private func ensureRootViewController(isPreAppReadyCall: Bool = false) {
guard AppReadiness.isAppReady() && GRDBStorage.shared.isValid && !hasInitialRootViewController else { guard (AppReadiness.isAppReady() || isPreAppReadyCall) && GRDBStorage.shared.isValid && !hasInitialRootViewController else {
return return
} }
@ -334,6 +338,13 @@ class AppDelegate: UIResponder, UIApplicationDelegate, UNUserNotificationCenterD
) )
) )
UIViewController.attemptRotationToDeviceOrientation() UIViewController.attemptRotationToDeviceOrientation()
/// **Note:** There is an annoying case when starting the app by interacting with a push notification where
/// the `HomeVC` won't have completed loading it's view which means the `SessionApp.homeViewController`
/// won't have been set - we set the value directly here to resolve this edge case
if let homeViewController: HomeVC = (self.window?.rootViewController as? UINavigationController)?.viewControllers.first as? HomeVC {
SessionApp.homeViewController.mutate { $0 = homeViewController }
}
} }
// MARK: - Notifications // MARK: - Notifications

@ -491,7 +491,7 @@ class NotificationActionHandler {
// If this happens when the the app is not, visible we skip the animation so the thread // If this happens when the the app is not, visible we skip the animation so the thread
// can be visible to the user immediately upon opening the app, rather than having to watch // can be visible to the user immediately upon opening the app, rather than having to watch
// it animate in from the homescreen. // it animate in from the homescreen.
let shouldAnimate = UIApplication.shared.applicationState == .active let shouldAnimate: Bool = (UIApplication.shared.applicationState == .active)
SessionApp.presentConversation(for: threadId, animated: shouldAnimate) SessionApp.presentConversation(for: threadId, animated: shouldAnimate)
return Promise.value(()) return Promise.value(())
} }

@ -36,7 +36,7 @@ public final class BackgroundPoller: NSObject {
let poller: OpenGroupAPI.Poller = OpenGroupAPI.Poller(for: server) let poller: OpenGroupAPI.Poller = OpenGroupAPI.Poller(for: server)
poller.stop() poller.stop()
return poller.poll(isBackgroundPoll: true) return poller.poll(isBackgroundPoll: true, isPostCapabilitiesRetry: false)
} }
) )

@ -42,6 +42,7 @@ enum _003_YDBToGRDBMigration: Migration {
var closedGroupModel: [String: SMKLegacy._GroupModel] = [:] var closedGroupModel: [String: SMKLegacy._GroupModel] = [:]
var closedGroupZombieMemberIds: [String: Set<String>] = [:] var closedGroupZombieMemberIds: [String: Set<String>] = [:]
var openGroupServer: [String: String] = [:]
var openGroupInfo: [String: SMKLegacy._OpenGroup] = [:] var openGroupInfo: [String: SMKLegacy._OpenGroup] = [:]
var openGroupUserCount: [String: Int64] = [:] var openGroupUserCount: [String: Int64] = [:]
var openGroupImage: [String: Data] = [:] var openGroupImage: [String: Data] = [:]
@ -171,10 +172,25 @@ enum _003_YDBToGRDBMigration: Migration {
return return
} }
// We want to migrate everyone over to using the domain name for open group
// servers rather than the IP, also best to use HTTPS over HTTP where possible
// so catch the case where we have the domain with HTTP (the 'defaultServer'
// value contains a HTTPS scheme so we get IP HTTP -> HTTPS for free as well)
let processedOpenGroupServer: String = {
// Check if the server is a Session-run one based on it's
guard
openGroup.server.contains(OpenGroupAPI.legacyDefaultServerIP) ||
openGroup.server == OpenGroupAPI.defaultServer
.replacingOccurrences(of: "https://", with: "http://")
else { return openGroup.server }
return OpenGroupAPI.defaultServer
}()
legacyThreadIdToIdMap[thread.uniqueId] = OpenGroup.idFor( legacyThreadIdToIdMap[thread.uniqueId] = OpenGroup.idFor(
roomToken: openGroup.room, roomToken: openGroup.room,
server: openGroup.server server: processedOpenGroupServer
) )
openGroupServer[thread.uniqueId] = processedOpenGroupServer
openGroupInfo[thread.uniqueId] = openGroup openGroupInfo[thread.uniqueId] = openGroup
openGroupUserCount[thread.uniqueId] = ((transaction.object(forKey: openGroup.id, inCollection: SMKLegacy.openGroupUserCountCollection) as? Int64) ?? 0) openGroupUserCount[thread.uniqueId] = ((transaction.object(forKey: openGroup.id, inCollection: SMKLegacy.openGroupUserCountCollection) as? Int64) ?? 0)
openGroupImage[thread.uniqueId] = transaction.object(forKey: openGroup.id, inCollection: SMKLegacy.openGroupImageCollection) as? Data openGroupImage[thread.uniqueId] = transaction.object(forKey: openGroup.id, inCollection: SMKLegacy.openGroupImageCollection) as? Data
@ -641,13 +657,16 @@ enum _003_YDBToGRDBMigration: Migration {
// Open Groups // Open Groups
if legacyThread.isOpenGroup { if legacyThread.isOpenGroup {
guard let openGroup: SMKLegacy._OpenGroup = openGroupInfo[legacyThread.uniqueId] else { guard
let openGroup: SMKLegacy._OpenGroup = openGroupInfo[legacyThread.uniqueId],
let targetOpenGroupServer: String = openGroupServer[legacyThread.uniqueId]
else {
SNLog("[Migration Error] Open group missing required data") SNLog("[Migration Error] Open group missing required data")
throw StorageError.migrationFailed throw StorageError.migrationFailed
} }
try OpenGroup( try OpenGroup(
server: openGroup.server, server: targetOpenGroupServer,
roomToken: openGroup.room, roomToken: openGroup.room,
publicKey: openGroup.publicKey, publicKey: openGroup.publicKey,
isActive: true, isActive: true,

@ -11,8 +11,8 @@ import SessionUtilitiesKit
public enum OpenGroupAPI { public enum OpenGroupAPI {
// MARK: - Settings // MARK: - Settings
public static let legacyDefaultServerDNS = "open.getsession.org" public static let legacyDefaultServerIP = "116.203.70.33"
public static let defaultServer = "http://116.203.70.33" public static let defaultServer = "https://open.getsession.org"
public static let defaultServerPublicKey = "a03c383cf63c3c4efe67acc52112a6dd734b3a946b9545f488aaa93da7991238" public static let defaultServerPublicKey = "a03c383cf63c3c4efe67acc52112a6dd734b3a946b9545f488aaa93da7991238"
public static let workQueue = DispatchQueue(label: "OpenGroupAPI.workQueue", qos: .userInitiated) // It's important that this is a serial queue public static let workQueue = DispatchQueue(label: "OpenGroupAPI.workQueue", qos: .userInitiated) // It's important that this is a serial queue
@ -225,6 +225,7 @@ public enum OpenGroupAPI {
public static func capabilities( public static func capabilities(
_ db: Database, _ db: Database,
server: String, server: String,
authenticated: Bool = true,
using dependencies: SMKDependencies = SMKDependencies() using dependencies: SMKDependencies = SMKDependencies()
) -> Promise<(OnionRequestResponseInfoType, Capabilities)> { ) -> Promise<(OnionRequestResponseInfoType, Capabilities)> {
return OpenGroupAPI return OpenGroupAPI
@ -234,6 +235,7 @@ public enum OpenGroupAPI {
server: server, server: server,
endpoint: .capabilities endpoint: .capabilities
), ),
authenticated: authenticated,
using: dependencies using: dependencies
) )
.decoded(as: Capabilities.self, on: OpenGroupAPI.workQueue, using: dependencies) .decoded(as: Capabilities.self, on: OpenGroupAPI.workQueue, using: dependencies)
@ -394,7 +396,7 @@ public enum OpenGroupAPI {
using dependencies: SMKDependencies = SMKDependencies() using dependencies: SMKDependencies = SMKDependencies()
) -> Promise<(OnionRequestResponseInfoType, Message)> { ) -> Promise<(OnionRequestResponseInfoType, Message)> {
guard let signResult: (publicKey: String, signature: Bytes) = sign(db, messageBytes: plaintext.bytes, for: server, fallbackSigningType: .standard, using: dependencies) else { guard let signResult: (publicKey: String, signature: Bytes) = sign(db, messageBytes: plaintext.bytes, for: server, fallbackSigningType: .standard, using: dependencies) else {
return Promise(error: Error.signingFailed) return Promise(error: OpenGroupAPIError.signingFailed)
} }
return OpenGroupAPI return OpenGroupAPI
@ -450,7 +452,7 @@ public enum OpenGroupAPI {
using dependencies: SMKDependencies = SMKDependencies() using dependencies: SMKDependencies = SMKDependencies()
) -> Promise<(OnionRequestResponseInfoType, Data?)> { ) -> Promise<(OnionRequestResponseInfoType, Data?)> {
guard let signResult: (publicKey: String, signature: Bytes) = sign(db, messageBytes: plaintext.bytes, for: server, fallbackSigningType: .standard, using: dependencies) else { guard let signResult: (publicKey: String, signature: Bytes) = sign(db, messageBytes: plaintext.bytes, for: server, fallbackSigningType: .standard, using: dependencies) else {
return Promise(error: Error.signingFailed) return Promise(error: OpenGroupAPIError.signingFailed)
} }
return OpenGroupAPI return OpenGroupAPI
@ -1223,7 +1225,7 @@ public enum OpenGroupAPI {
.asRequest(of: String.self) .asRequest(of: String.self)
.fetchOne(db) .fetchOne(db)
guard let publicKey: String = maybePublicKey else { return Promise(error: Error.noPublicKey) } guard let publicKey: String = maybePublicKey else { return Promise(error: OpenGroupAPIError.noPublicKey) }
// If we don't want to authenticate the request then send it immediately // If we don't want to authenticate the request then send it immediately
guard authenticated else { guard authenticated else {
@ -1232,7 +1234,7 @@ public enum OpenGroupAPI {
// Attempt to sign the request with the new auth // Attempt to sign the request with the new auth
guard let signedRequest: URLRequest = sign(db, request: urlRequest, for: request.server, with: publicKey, using: dependencies) else { guard let signedRequest: URLRequest = sign(db, request: urlRequest, for: request.server, with: publicKey, using: dependencies) else {
return Promise(error: Error.signingFailed) return Promise(error: OpenGroupAPIError.signingFailed)
} }
return dependencies.onionApi.sendOnionRequest(signedRequest, to: request.server, with: publicKey) return dependencies.onionApi.sendOnionRequest(signedRequest, to: request.server, with: publicKey)

@ -113,7 +113,7 @@ public final class OpenGroupManager: NSObject {
let serverHost: String = (serverUrl.host ?? server.lowercased()) let serverHost: String = (serverUrl.host ?? server.lowercased())
let serverPort: String = (serverUrl.port.map { ":\($0)" } ?? "") let serverPort: String = (serverUrl.port.map { ":\($0)" } ?? "")
let defaultServerHost: String = OpenGroupAPI.defaultServer.substring(from: "http://".count) let defaultServerHost: String = OpenGroupAPI.defaultServer.substring(from: "https://".count)
var serverOptions: Set<String> = Set([ var serverOptions: Set<String> = Set([
server.lowercased(), server.lowercased(),
"\(serverHost)\(serverPort)", "\(serverHost)\(serverPort)",
@ -121,21 +121,15 @@ public final class OpenGroupManager: NSObject {
"https://\(serverHost)\(serverPort)" "https://\(serverHost)\(serverPort)"
]) ])
if serverHost == OpenGroupAPI.legacyDefaultServerDNS { if serverHost == OpenGroupAPI.legacyDefaultServerIP {
let defaultServerOptions: Set<String> = Set([ serverOptions.insert(defaultServerHost)
defaultServerHost, serverOptions.insert("http://\(defaultServerHost)")
OpenGroupAPI.defaultServer, serverOptions.insert(OpenGroupAPI.defaultServer)
"https://\(defaultServerHost)"
])
serverOptions = serverOptions.union(defaultServerOptions)
} }
else if serverHost == defaultServerHost { else if serverHost == defaultServerHost {
let legacyServerOptions: Set<String> = Set([ serverOptions.insert(OpenGroupAPI.legacyDefaultServerIP)
OpenGroupAPI.legacyDefaultServerDNS, serverOptions.insert("http://\(OpenGroupAPI.legacyDefaultServerIP)")
"http://\(OpenGroupAPI.legacyDefaultServerDNS)", serverOptions.insert("https://\(OpenGroupAPI.legacyDefaultServerIP)")
"https://\(OpenGroupAPI.legacyDefaultServerDNS)"
])
serverOptions = serverOptions.union(legacyServerOptions)
} }
// First check if there is no poller for the specified server // First check if there is no poller for the specified server
@ -352,7 +346,7 @@ public final class OpenGroupManager: NSObject {
nil nil
), ),
(openGroup.imageId != pollInfo.details?.imageId.map { "\($0)" } ? (openGroup.imageId != pollInfo.details?.imageId.map { "\($0)" } ?
(pollInfo.details?.imageId).map { OpenGroup.Columns.roomDescription.set(to: "\($0)") } : (pollInfo.details?.imageId).map { OpenGroup.Columns.imageId.set(to: "\($0)") } :
nil nil
), ),
(openGroup.userCount != pollInfo.activeUsers ? (openGroup.userCount != pollInfo.activeUsers ?

@ -0,0 +1,17 @@
// Copyright © 2022 Rangeproof Pty Ltd. All rights reserved.
import Foundation
public enum OpenGroupAPIError: LocalizedError {
case decryptionFailed
case signingFailed
case noPublicKey
public var errorDescription: String? {
switch self {
case .decryptionFailed: return "Couldn't decrypt response."
case .signingFailed: return "Couldn't sign message."
case .noPublicKey: return "Couldn't find server public key."
}
}
}

@ -1,19 +0,0 @@
// Copyright © 2022 Rangeproof Pty Ltd. All rights reserved.
import Foundation
extension OpenGroupAPI {
public enum Error: LocalizedError {
case decryptionFailed
case signingFailed
case noPublicKey
public var errorDescription: String? {
switch self {
case .decryptionFailed: return "Couldn't decrypt response."
case .signingFailed: return "Couldn't sign message."
case .noPublicKey: return "Couldn't find server public key."
}
}
}
}

@ -4,8 +4,12 @@ import Foundation
extension PushNotificationAPI { extension PushNotificationAPI {
struct UpdateRegistrationResponse: Codable { struct UpdateRegistrationResponse: Codable {
let body: String struct Body: Codable {
let code: Int let code: Int
let message: String? let message: String?
}
let status: Int
let body: Body
} }
} }

@ -60,8 +60,8 @@ public final class PushNotificationAPI : NSObject {
guard let response: UpdateRegistrationResponse = try? response?.decoded(as: UpdateRegistrationResponse.self) else { guard let response: UpdateRegistrationResponse = try? response?.decoded(as: UpdateRegistrationResponse.self) else {
return SNLog("Couldn't unregister from push notifications.") return SNLog("Couldn't unregister from push notifications.")
} }
guard response.code != 0 else { guard response.body.code != 0 else {
return SNLog("Couldn't unregister from push notifications due to error: \(response.message ?? "nil").") return SNLog("Couldn't unregister from push notifications due to error: \(response.body.message ?? "nil").")
} }
} }
} }
@ -119,8 +119,8 @@ public final class PushNotificationAPI : NSObject {
guard let response: UpdateRegistrationResponse = try? response?.decoded(as: UpdateRegistrationResponse.self) else { guard let response: UpdateRegistrationResponse = try? response?.decoded(as: UpdateRegistrationResponse.self) else {
return SNLog("Couldn't register device token.") return SNLog("Couldn't register device token.")
} }
guard response.code != 0 else { guard response.body.code != 0 else {
return SNLog("Couldn't register device token due to error: \(response.message ?? "nil").") return SNLog("Couldn't register device token due to error: \(response.body.message ?? "nil").")
} }
userDefaults[.deviceToken] = hexEncodedToken userDefaults[.deviceToken] = hexEncodedToken
@ -180,8 +180,8 @@ public final class PushNotificationAPI : NSObject {
guard let response: UpdateRegistrationResponse = try? response?.decoded(as: UpdateRegistrationResponse.self) else { guard let response: UpdateRegistrationResponse = try? response?.decoded(as: UpdateRegistrationResponse.self) else {
return SNLog("Couldn't subscribe/unsubscribe for closed group: \(closedGroupPublicKey).") return SNLog("Couldn't subscribe/unsubscribe for closed group: \(closedGroupPublicKey).")
} }
guard response.code != 0 else { guard response.body.code != 0 else {
return SNLog("Couldn't subscribe/unsubscribe for closed group: \(closedGroupPublicKey) due to error: \(response.message ?? "nil").") return SNLog("Couldn't subscribe/unsubscribe for closed group: \(closedGroupPublicKey) due to error: \(response.body.message ?? "nil").")
} }
} }
} }

@ -43,11 +43,15 @@ extension OpenGroupAPI {
@discardableResult @discardableResult
public func poll(using dependencies: OpenGroupManager.OGMDependencies = OpenGroupManager.OGMDependencies()) -> Promise<Void> { public func poll(using dependencies: OpenGroupManager.OGMDependencies = OpenGroupManager.OGMDependencies()) -> Promise<Void> {
return poll(isBackgroundPoll: false, using: dependencies) return poll(isBackgroundPoll: false, isPostCapabilitiesRetry: false, using: dependencies)
} }
@discardableResult @discardableResult
public func poll(isBackgroundPoll: Bool, using dependencies: OpenGroupManager.OGMDependencies = OpenGroupManager.OGMDependencies()) -> Promise<Void> { public func poll(
isBackgroundPoll: Bool,
isPostCapabilitiesRetry: Bool,
using dependencies: OpenGroupManager.OGMDependencies = OpenGroupManager.OGMDependencies()
) -> Promise<Void> {
guard !self.isPolling else { return Promise.value(()) } guard !self.isPolling else { return Promise.value(()) }
self.isPolling = true self.isPolling = true
@ -83,15 +87,93 @@ extension OpenGroupAPI {
seal.fulfill(()) seal.fulfill(())
} }
.catch(on: OpenGroupAPI.workQueue) { [weak self] error in .catch(on: OpenGroupAPI.workQueue) { [weak self] error in
SNLog("Open group polling failed due to error: \(error).") // If we are retrying then the error is being handled so no need to continue (this
self?.isPolling = false // method will always resolve)
seal.fulfill(()) // The promise is just used to keep track of when we're done self?.updateCapabilitiesAndRetryIfNeeded(
server: server,
isBackgroundPoll: isBackgroundPoll,
isPostCapabilitiesRetry: isPostCapabilitiesRetry,
error: error
)
.done(on: OpenGroupAPI.workQueue) { [weak self] didHandleError in
if !didHandleError {
SNLog("Open group polling failed due to error: \(error).")
}
self?.isPolling = false
seal.fulfill(()) // The promise is just used to keep track of when we're done
}
.retainUntilComplete()
} }
} }
return promise return promise
} }
private func updateCapabilitiesAndRetryIfNeeded(
server: String,
isBackgroundPoll: Bool,
isPostCapabilitiesRetry: Bool,
error: Error,
using dependencies: OpenGroupManager.OGMDependencies = OpenGroupManager.OGMDependencies()
) -> Promise<Bool> {
/// We want to custom handle a '400' error code due to not having blinded auth as it likely means that we join the
/// OpenGroup before blinding was enabled and need to update it's capabilities
///
/// **Note:** To prevent an infinite loop caused by a server-side bug we want to prevent this capabilities request from
/// happening multiple times in a row
guard
!isPostCapabilitiesRetry,
let error: OnionRequestAPIError = error as? OnionRequestAPIError,
case .httpRequestFailedAtDestination(let statusCode, let data, _) = error,
statusCode == 400,
let dataString: String = String(data: data, encoding: .utf8),
dataString.contains("Invalid authentication: this server requires the use of blinded idse")
else { return Promise.value(false) }
let (promise, seal) = Promise<Bool>.pending()
dependencies.storage
.read { db in
OpenGroupAPI.capabilities(
db,
server: server,
authenticated: false,
using: dependencies
)
}
.then(on: OpenGroupAPI.workQueue) { [weak self] _, responseBody -> Promise<Void> in
guard let strongSelf = self else { return Promise.value(()) }
// Handle the updated capabilities and re-trigger the poll
strongSelf.isPolling = false
dependencies.storage.write { db in
OpenGroupManager.handleCapabilities(
db,
capabilities: responseBody,
on: server
)
}
// Regardless of the outcome we can just resolve this
// immediately as it'll handle it's own response
return strongSelf.poll(
isBackgroundPoll: isBackgroundPoll,
isPostCapabilitiesRetry: true,
using: dependencies
)
.ensure { seal.fulfill(true) }
}
.catch(on: OpenGroupAPI.workQueue) { error in
SNLog("Open group updating capabilities failed due to error: \(error).")
seal.fulfill(true)
}
.retainUntilComplete()
return promise
}
private func handlePollResponse(_ response: [OpenGroupAPI.Endpoint: (info: OnionRequestResponseInfoType, data: Codable?)], isBackgroundPoll: Bool, using dependencies: OpenGroupManager.OGMDependencies = OpenGroupManager.OGMDependencies()) { private func handlePollResponse(_ response: [OpenGroupAPI.Endpoint: (info: OnionRequestResponseInfoType, data: Codable?)], isBackgroundPoll: Bool, using dependencies: OpenGroupManager.OGMDependencies = OpenGroupManager.OGMDependencies()) {
let server: String = self.server let server: String = self.server

@ -715,18 +715,11 @@ public extension MessageViewModel.AttachmentInteractionInfo {
let interactionAttachment: TypedTableAlias<InteractionAttachment> = TypedTableAlias() let interactionAttachment: TypedTableAlias<InteractionAttachment> = TypedTableAlias()
return """ return """
JOIN \(InteractionAttachment.self) ON \(interactionAttachment[.attachmentId]) = \(attachment[.id]) JOIN \(InteractionAttachment.self) ON \(interactionAttachment[.interactionId]) = \(interaction[.id])
JOIN \(Interaction.self) ON JOIN \(Attachment.self) ON \(attachment[.id]) = \(interactionAttachment[.attachmentId])
\(interaction[.id]) = \(interactionAttachment[.interactionId])
""" """
}() }()
static var groupViewModelQuerySQL: SQL = {
let interaction: TypedTableAlias<Interaction> = TypedTableAlias()
return "\(interaction[.id])"
}()
static func createAssociateDataClosure() -> (DataCache<MessageViewModel.AttachmentInteractionInfo>, DataCache<MessageViewModel>) -> DataCache<MessageViewModel> { static func createAssociateDataClosure() -> (DataCache<MessageViewModel.AttachmentInteractionInfo>, DataCache<MessageViewModel>) -> DataCache<MessageViewModel> {
return { dataCache, pagedDataCache -> DataCache<MessageViewModel> in return { dataCache, pagedDataCache -> DataCache<MessageViewModel> in
var updatedPagedDataCache: DataCache<MessageViewModel> = pagedDataCache var updatedPagedDataCache: DataCache<MessageViewModel> = pagedDataCache
@ -786,7 +779,7 @@ public extension MessageViewModel.TypingIndicatorInfo {
let threadTypingIndicator: TypedTableAlias<ThreadTypingIndicator> = TypedTableAlias() let threadTypingIndicator: TypedTableAlias<ThreadTypingIndicator> = TypedTableAlias()
return """ return """
JOIN \(Interaction.self) ON \(interaction[.threadId]) = \(threadTypingIndicator[.threadId]) JOIN \(ThreadTypingIndicator.self) ON \(threadTypingIndicator[.threadId]) = \(interaction[.threadId])
""" """
}() }()

@ -1364,7 +1364,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1399,7 +1399,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1432,7 +1432,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1512,7 +1512,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1547,7 +1547,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1588,7 +1588,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1772,7 +1772,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1806,7 +1806,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1838,7 +1838,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1916,7 +1916,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1950,7 +1950,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -1990,7 +1990,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -2915,7 +2915,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -2942,7 +2942,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.noPublicKey.localizedDescription), equal(OpenGroupAPIError.noPublicKey.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -2969,7 +2969,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -3037,7 +3037,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -3108,7 +3108,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )
@ -3135,7 +3135,7 @@ class OpenGroupAPISpec: QuickSpec {
expect(error?.localizedDescription) expect(error?.localizedDescription)
.toEventually( .toEventually(
equal(OpenGroupAPI.Error.signingFailed.localizedDescription), equal(OpenGroupAPIError.signingFailed.localizedDescription),
timeout: .milliseconds(100) timeout: .milliseconds(100)
) )

@ -1578,6 +1578,7 @@ class OpenGroupManagerSpec: QuickSpec {
isActive: true, isActive: true,
name: "Test", name: "Test",
imageId: "12", imageId: "12",
imageData: Data([1, 2, 3]),
userCount: 0, userCount: 0,
infoUpdates: 10 infoUpdates: 10
).insert(db) ).insert(db)
@ -3004,7 +3005,7 @@ class OpenGroupManagerSpec: QuickSpec {
.asRequest(of: String.self) .asRequest(of: String.self)
.fetchOne(db) .fetchOne(db)
} }
).to(equal("http://116.203.70.33")) ).to(equal("https://open.getsession.org"))
expect( expect(
mockStorage.read { db in mockStorage.read { db in
try OpenGroup try OpenGroup

@ -13,11 +13,11 @@ class SOGSErrorSpec: QuickSpec {
override func spec() { override func spec() {
describe("a SOGSError") { describe("a SOGSError") {
it("generates the error description correctly") { it("generates the error description correctly") {
expect(OpenGroupAPI.Error.decryptionFailed.errorDescription) expect(OpenGroupAPIError.decryptionFailed.errorDescription)
.to(equal("Couldn't decrypt response.")) .to(equal("Couldn't decrypt response."))
expect(OpenGroupAPI.Error.signingFailed.errorDescription) expect(OpenGroupAPIError.signingFailed.errorDescription)
.to(equal("Couldn't sign message.")) .to(equal("Couldn't sign message."))
expect(OpenGroupAPI.Error.noPublicKey.errorDescription) expect(OpenGroupAPIError.noPublicKey.errorDescription)
.to(equal("Couldn't find server public key.")) .to(equal("Couldn't find server public key."))
} }
} }

@ -299,8 +299,12 @@ public final class NotificationServiceExtension: UNNotificationServiceExtension
private func pollForOpenGroups() -> [Promise<Void>] { private func pollForOpenGroups() -> [Promise<Void>] {
let promises: [Promise<Void>] = GRDBStorage.shared let promises: [Promise<Void>] = GRDBStorage.shared
.read { db in .read { db in
// The default room promise creates an OpenGroup with an empty `roomToken` value,
// we don't want to start a poller for this as the user hasn't actually joined a room
try OpenGroup try OpenGroup
.select(.server) .select(.server)
.filter(OpenGroup.Columns.roomToken != "")
.filter(OpenGroup.Columns.isActive)
.distinct() .distinct()
.asRequest(of: String.self) .asRequest(of: String.self)
.fetchSet(db) .fetchSet(db)
@ -308,7 +312,7 @@ public final class NotificationServiceExtension: UNNotificationServiceExtension
.defaulting(to: []) .defaulting(to: [])
.map { server in .map { server in
OpenGroupAPI.Poller(for: server) OpenGroupAPI.Poller(for: server)
.poll(isBackgroundPoll: true) .poll(isBackgroundPoll: true, isPostCapabilitiesRetry: false)
.timeout( .timeout(
seconds: 20, seconds: 20,
timeoutError: NotificationServiceError.timeout timeoutError: NotificationServiceError.timeout

@ -60,6 +60,7 @@ public class PagedDatabaseObserver<ObservedTable, T>: TransactionObserver where
self.orderSQL = orderSQL self.orderSQL = orderSQL
self.dataQuery = dataQuery self.dataQuery = dataQuery
self.associatedRecords = associatedRecords self.associatedRecords = associatedRecords
.map { $0.settingPagedTableName(pagedTableName: pagedTable.databaseTableName) }
self.onChangeUnsorted = onChangeUnsorted self.onChangeUnsorted = onChangeUnsorted
// Combine the various observed changes into a single set // Combine the various observed changes into a single set
@ -141,6 +142,7 @@ public class PagedDatabaseObserver<ObservedTable, T>: TransactionObserver where
let hasChanges: Bool = associatedRecord.tryUpdateForDatabaseCommit( let hasChanges: Bool = associatedRecord.tryUpdateForDatabaseCommit(
db, db,
changes: committedChanges, changes: committedChanges,
joinSQL: joinSQL,
orderSQL: orderSQL, orderSQL: orderSQL,
filterSQL: filterSQL, filterSQL: filterSQL,
pageInfo: updatedPageInfo pageInfo: updatedPageInfo
@ -616,13 +618,15 @@ public protocol FetchableRecordWithRowId: FetchableRecord {
public protocol ErasedAssociatedRecord { public protocol ErasedAssociatedRecord {
var databaseTableName: String { get } var databaseTableName: String { get }
var pagedTableName: String { get }
var observedChanges: [PagedData.ObservedChanges] { get } var observedChanges: [PagedData.ObservedChanges] { get }
var joinToPagedType: SQL { get } var joinToPagedType: SQL { get }
var groupPagedType: SQL? { get }
func settingPagedTableName(pagedTableName: String) -> Self
func tryUpdateForDatabaseCommit( func tryUpdateForDatabaseCommit(
_ db: Database, _ db: Database,
changes: Set<PagedData.TrackedChange>, changes: Set<PagedData.TrackedChange>,
joinSQL: SQL?,
orderSQL: SQL, orderSQL: SQL,
filterSQL: SQL, filterSQL: SQL,
pageInfo: PagedData.PageInfo pageInfo: PagedData.PageInfo
@ -886,45 +890,11 @@ public enum PagedData {
tableName: String, tableName: String,
requiredJoinSQL: SQL? = nil, requiredJoinSQL: SQL? = nil,
orderSQL: SQL, orderSQL: SQL,
filterSQL: SQL, filterSQL: SQL
joinToPagedType: SQL? = nil,
groupPagedType: SQL? = nil
) -> [Int64] { ) -> [Int64] {
guard !rowIds.isEmpty else { return [] } guard !rowIds.isEmpty else { return [] }
let tableNameLiteral: SQL = SQL(stringLiteral: tableName) let tableNameLiteral: SQL = SQL(stringLiteral: tableName)
/// **Note:** `ROW_NUMBER` works by returning the index of the row in a given query, unfortunately when dealing
/// with associated data its possible for multiple results to connect to an individual paged result, this throws off the
/// indexes so in this case we need to do some sneaky aggregation and grouping and then individually retrieve each
/// index to prevent this
guard joinToPagedType == nil || rowIds.count == 1 else {
guard let groupPagedType: SQL = groupPagedType else { return [] }
let groupByLiteral: SQL = SQL(stringLiteral: "GROUP BY ")
return rowIds.compactMap { rowId in
let groupedRequest: SQLRequest<Int64> = """
SELECT
(data.rowIndex - 1) AS rowIndex -- Converting from 1-Indexed to 0-indexed
FROM (
SELECT
\(tableNameLiteral).rowid AS rowid,
\(SQL("MAX(\(tableNameLiteral).rowid = \(rowId))")),
ROW_NUMBER() OVER (ORDER BY \(orderSQL)) AS rowIndex
FROM \(tableNameLiteral)
\(requiredJoinSQL ?? "")
\(joinToPagedType ?? "")
WHERE \(filterSQL)
\(groupByLiteral)\(groupPagedType)
) AS data
WHERE \(SQL("data.rowid = \(rowId)"))
"""
return try? groupedRequest.fetchOne(db)
}
}
let request: SQLRequest<Int64> = """ let request: SQLRequest<Int64> = """
SELECT SELECT
(data.rowIndex - 1) AS rowIndex -- Converting from 1-Indexed to 0-indexed (data.rowIndex - 1) AS rowIndex -- Converting from 1-Indexed to 0-indexed
@ -934,7 +904,6 @@ public enum PagedData {
ROW_NUMBER() OVER (ORDER BY \(orderSQL)) AS rowIndex ROW_NUMBER() OVER (ORDER BY \(orderSQL)) AS rowIndex
FROM \(tableNameLiteral) FROM \(tableNameLiteral)
\(requiredJoinSQL ?? "") \(requiredJoinSQL ?? "")
\(joinToPagedType ?? "")
WHERE \(filterSQL) WHERE \(filterSQL)
) AS data ) AS data
WHERE \(SQL("data.rowid IN \(rowIds)")) WHERE \(SQL("data.rowid IN \(rowIds)"))
@ -958,7 +927,7 @@ public enum PagedData {
let pagedTableNameLiteral: SQL = SQL(stringLiteral: pagedTableName) let pagedTableNameLiteral: SQL = SQL(stringLiteral: pagedTableName)
let request: SQLRequest<Int64> = """ let request: SQLRequest<Int64> = """
SELECT \(tableNameLiteral).rowid AS rowid SELECT \(tableNameLiteral).rowid AS rowid
FROM \(tableNameLiteral) FROM \(pagedTableNameLiteral)
\(joinToPagedType) \(joinToPagedType)
WHERE \(pagedTableNameLiteral).rowId IN \(pagedTypeRowIds) WHERE \(pagedTableNameLiteral).rowId IN \(pagedTypeRowIds)
""" """
@ -995,9 +964,9 @@ public enum PagedData {
public class AssociatedRecord<T, PagedType>: ErasedAssociatedRecord where T: FetchableRecordWithRowId & Identifiable, PagedType: FetchableRecordWithRowId & Identifiable { public class AssociatedRecord<T, PagedType>: ErasedAssociatedRecord where T: FetchableRecordWithRowId & Identifiable, PagedType: FetchableRecordWithRowId & Identifiable {
public let databaseTableName: String public let databaseTableName: String
public private(set) var pagedTableName: String = ""
public let observedChanges: [PagedData.ObservedChanges] public let observedChanges: [PagedData.ObservedChanges]
public let joinToPagedType: SQL public let joinToPagedType: SQL
public let groupPagedType: SQL?
fileprivate let dataCache: Atomic<DataCache<T>> = Atomic(DataCache()) fileprivate let dataCache: Atomic<DataCache<T>> = Atomic(DataCache())
fileprivate let dataQuery: (SQL?) -> AdaptedFetchRequest<SQLRequest<T>> fileprivate let dataQuery: (SQL?) -> AdaptedFetchRequest<SQLRequest<T>>
@ -1010,14 +979,12 @@ public class AssociatedRecord<T, PagedType>: ErasedAssociatedRecord where T: Fet
observedChanges: [PagedData.ObservedChanges], observedChanges: [PagedData.ObservedChanges],
dataQuery: @escaping (SQL?) -> AdaptedFetchRequest<SQLRequest<T>>, dataQuery: @escaping (SQL?) -> AdaptedFetchRequest<SQLRequest<T>>,
joinToPagedType: SQL, joinToPagedType: SQL,
groupPagedType: SQL? = nil,
associateData: @escaping (DataCache<T>, DataCache<PagedType>) -> DataCache<PagedType> associateData: @escaping (DataCache<T>, DataCache<PagedType>) -> DataCache<PagedType>
) { ) {
self.databaseTableName = trackedAgainst.databaseTableName self.databaseTableName = trackedAgainst.databaseTableName
self.observedChanges = observedChanges self.observedChanges = observedChanges
self.dataQuery = dataQuery self.dataQuery = dataQuery
self.joinToPagedType = joinToPagedType self.joinToPagedType = joinToPagedType
self.groupPagedType = groupPagedType
self.associateData = associateData self.associateData = associateData
} }
@ -1026,7 +993,6 @@ public class AssociatedRecord<T, PagedType>: ErasedAssociatedRecord where T: Fet
observedChanges: [PagedData.ObservedChanges], observedChanges: [PagedData.ObservedChanges],
dataQuery: @escaping (SQL?) -> SQLRequest<T>, dataQuery: @escaping (SQL?) -> SQLRequest<T>,
joinToPagedType: SQL, joinToPagedType: SQL,
groupPagedType: SQL? = nil,
associateData: @escaping (DataCache<T>, DataCache<PagedType>) -> DataCache<PagedType> associateData: @escaping (DataCache<T>, DataCache<PagedType>) -> DataCache<PagedType>
) { ) {
self.init( self.init(
@ -1036,16 +1002,21 @@ public class AssociatedRecord<T, PagedType>: ErasedAssociatedRecord where T: Fet
dataQuery(additionalFilters).adapted { _ in ScopeAdapter([:]) } dataQuery(additionalFilters).adapted { _ in ScopeAdapter([:]) }
}, },
joinToPagedType: joinToPagedType, joinToPagedType: joinToPagedType,
groupPagedType: groupPagedType,
associateData: associateData associateData: associateData
) )
} }
// MARK: - AssociatedRecord // MARK: - AssociatedRecord
public func settingPagedTableName(pagedTableName: String) -> Self {
self.pagedTableName = pagedTableName
return self
}
public func tryUpdateForDatabaseCommit( public func tryUpdateForDatabaseCommit(
_ db: Database, _ db: Database,
changes: Set<PagedData.TrackedChange>, changes: Set<PagedData.TrackedChange>,
joinSQL: SQL?,
orderSQL: SQL, orderSQL: SQL,
filterSQL: SQL, filterSQL: SQL,
pageInfo: PagedData.PageInfo pageInfo: PagedData.PageInfo
@ -1075,44 +1046,52 @@ public class AssociatedRecord<T, PagedType>: ErasedAssociatedRecord where T: Fet
guard !rowIdsToQuery.isEmpty else { return (oldCount != countAfterDeletions) } guard !rowIdsToQuery.isEmpty else { return (oldCount != countAfterDeletions) }
// Fetch the indexes of the rowIds so we can determine whether they should be added to the screen // Fetch the indexes of the rowIds so we can determine whether they should be added to the screen
let itemIndexes: [Int64] = PagedData.indexes( let pagedRowIds: [Int64] = PagedData.pagedRowIdsForRelatedRowIds(
db, db,
rowIds: rowIdsToQuery,
tableName: databaseTableName, tableName: databaseTableName,
pagedTableName: pagedTableName,
relatedRowIds: rowIdsToQuery,
joinToPagedType: joinToPagedType
)
// If the associated data change isn't related to the paged type then no need to continue
guard !pagedRowIds.isEmpty else { return (oldCount != countAfterDeletions) }
let pagedItemIndexes: [Int64] = PagedData.indexes(
db,
rowIds: pagedRowIds,
tableName: pagedTableName,
requiredJoinSQL: joinSQL,
orderSQL: orderSQL, orderSQL: orderSQL,
filterSQL: filterSQL, filterSQL: filterSQL
joinToPagedType: joinToPagedType,
groupPagedType: groupPagedType
) )
// Determine if the indexes for the row ids should be displayed on the screen and remove any // If we can't get the item indexes for the paged row ids then it's likely related to data
// which shouldn't - values less than 'currentCount' or if there is at least one value less than // which was filtered out (eg. message attachment related to a different thread)
// 'currentCount' and the indexes are sequential (ie. more than the current loaded content was guard !pagedItemIndexes.isEmpty else { return (oldCount != countAfterDeletions) }
// added at once)
let uniqueIndexes: [Int64] = itemIndexes.asSet().sorted() /// **Note:** The `PagedData.indexes` works by returning the index of a row in a given query, unfortunately when
let itemIndexesAreSequential: Bool = (uniqueIndexes.map { $0 - 1 }.dropFirst() == uniqueIndexes.dropLast()) /// dealing with associated data its possible for multiple associated data values to connect to an individual paged result,
let hasOneValidIndex: Bool = itemIndexes.contains(where: { index -> Bool in /// this throws off the indexes so we can't actually tell what `rowIdsToQuery` value is associated to which
/// `pagedItemIndexes` value
///
/// Instead of following the pattern the `PagedDatabaseObserver` does where we get the proper `validRowIds` we
/// basically have to check if there is a single valid index, and if so retrieve and store all data related to the changes for this
/// commit - this will mean in some cases we cache data which is actually unrelated to the filtered paged data
let hasOneValidIndex: Bool = pagedItemIndexes.contains(where: { index -> Bool in
index >= pageInfo.pageOffset && ( index >= pageInfo.pageOffset && (
index < pageInfo.currentCount || index < pageInfo.currentCount ||
pageInfo.currentCount == 0 pageInfo.currentCount == 0
) )
}) })
let validRowIds: [Int64] = (itemIndexesAreSequential && hasOneValidIndex ?
rowIdsToQuery : // Don't bother continuing if we don't have a valid index
zip(itemIndexes, rowIdsToQuery) guard hasOneValidIndex else { return (oldCount != countAfterDeletions) }
.filter { index, _ -> Bool in
index >= pageInfo.pageOffset && (
index < pageInfo.currentCount ||
pageInfo.currentCount == 0
)
}
.map { _, rowId -> Int64 in rowId }
)
// Attempt to update the cache with the `validRowIds` array // Attempt to update the cache with the `validRowIds` array
return updateCache( return updateCache(
db, db,
rowIds: validRowIds, rowIds: rowIdsToQuery,
hasOtherChanges: (oldCount != countAfterDeletions) hasOtherChanges: (oldCount != countAfterDeletions)
) )
} }

Loading…
Cancel
Save