From b621bdc69a573f52fb26083826fe89e288f3424f Mon Sep 17 00:00:00 2001 From: Niels Andriesse Date: Mon, 30 Nov 2020 11:00:28 +1100 Subject: [PATCH] Partially fix open groups --- .../Database/Storage+Jobs.swift | 2 +- .../Database/Storage+Messaging.swift | 12 ++++++---- .../Database/Storage+OpenGroups.swift | 12 ++++++++++ .../Jobs/MessageReceiveJob.swift | 24 ++++++++++++------- .../Signal/TSIncomingMessage+Conversion.swift | 8 ++++--- .../MessageReceiver+Handling.swift | 10 ++++---- .../Sending & Receiving/MessageReceiver.swift | 24 ++++++++++++------- .../Sending & Receiving/MessageSender.swift | 7 ++++++ .../Pollers/OpenGroupPoller.swift | 4 ++-- SessionMessagingKit/Storage.swift | 5 ++-- 10 files changed, 75 insertions(+), 33 deletions(-) diff --git a/SessionMessagingKit/Database/Storage+Jobs.swift b/SessionMessagingKit/Database/Storage+Jobs.swift index 8b0bf3315..8e9b5ab95 100644 --- a/SessionMessagingKit/Database/Storage+Jobs.swift +++ b/SessionMessagingKit/Database/Storage+Jobs.swift @@ -16,7 +16,7 @@ extension Storage { public func getAllPendingJobs(of type: Job.Type) -> [Job] { var result: [Job] = [] Storage.read { transaction in - transaction.enumerateRows(inCollection: type.collection) { _, object, _, _ in + transaction.enumerateRows(inCollection: type.collection) { key, object, _, x in guard let job = object as? Job else { return } result.append(job) } diff --git a/SessionMessagingKit/Database/Storage+Messaging.swift b/SessionMessagingKit/Database/Storage+Messaging.swift index 9783b72e1..e725a0c4a 100644 --- a/SessionMessagingKit/Database/Storage+Messaging.swift +++ b/SessionMessagingKit/Database/Storage+Messaging.swift @@ -17,10 +17,14 @@ extension Storage { } /// Returns the ID of the thread. - public func getOrCreateThread(for publicKey: String, groupPublicKey: String?, using transaction: Any) -> String? { + public func getOrCreateThread(for publicKey: String, groupPublicKey: String?, openGroupID: String?, using transaction: Any) -> String? { let transaction = transaction as! YapDatabaseReadWriteTransaction var threadOrNil: TSThread? - if let groupPublicKey = groupPublicKey { + if let openGroupID = openGroupID { + if let threadID = Storage.shared.getThreadID(for: openGroupID), let thread = TSGroupThread.fetch(uniqueId: threadID, transaction: transaction) { + threadOrNil = thread + } + } else if let groupPublicKey = groupPublicKey { guard Storage.shared.isClosedGroup(groupPublicKey) else { return nil } let groupID = LKGroupUtilities.getEncodedClosedGroupIDAsData(groupPublicKey) threadOrNil = TSGroupThread.fetch(uniqueId: TSGroupThread.threadId(fromGroupId: groupID), transaction: transaction) @@ -31,9 +35,9 @@ extension Storage { } /// Returns the ID of the `TSIncomingMessage` that was constructed. - public func persist(_ message: VisibleMessage, quotedMessage: TSQuotedMessage?, linkPreview: OWSLinkPreview?, groupPublicKey: String?, using transaction: Any) -> String? { + public func persist(_ message: VisibleMessage, quotedMessage: TSQuotedMessage?, linkPreview: OWSLinkPreview?, groupPublicKey: String?, openGroupID: String?, using transaction: Any) -> String? { let transaction = transaction as! YapDatabaseReadWriteTransaction - guard let threadID = getOrCreateThread(for: message.sender!, groupPublicKey: groupPublicKey, using: transaction), + guard let threadID = getOrCreateThread(for: message.sender!, groupPublicKey: groupPublicKey, openGroupID: openGroupID, using: transaction), let thread = TSThread.fetch(uniqueId: threadID, transaction: transaction) else { return nil } let message = TSIncomingMessage.from(message, quotedMessage: quotedMessage, linkPreview: linkPreview, associatedWith: thread) message.save(with: transaction) diff --git a/SessionMessagingKit/Database/Storage+OpenGroups.swift b/SessionMessagingKit/Database/Storage+OpenGroups.swift index 8fc9ebc94..ed3fcd90a 100644 --- a/SessionMessagingKit/Database/Storage+OpenGroups.swift +++ b/SessionMessagingKit/Database/Storage+OpenGroups.swift @@ -24,6 +24,18 @@ extension Storage { } return result } + + public func getThreadID(for openGroupID: String) -> String? { + var result: String? + Storage.read { transaction in + transaction.enumerateKeysAndObjects(inCollection: Storage.openGroupCollection, using: { threadID, object, stop in + guard let openGroup = object as? OpenGroup, "\(openGroup.server).\(openGroup.channel)" == openGroupID else { return } + result = threadID + stop.pointee = true + }) + } + return result + } @objc(setOpenGroup:forThreadWithID:using:) public func setOpenGroup(_ openGroup: OpenGroup, for threadID: String, using transaction: Any) { diff --git a/SessionMessagingKit/Jobs/MessageReceiveJob.swift b/SessionMessagingKit/Jobs/MessageReceiveJob.swift index b66a039f8..8638cb539 100644 --- a/SessionMessagingKit/Jobs/MessageReceiveJob.swift +++ b/SessionMessagingKit/Jobs/MessageReceiveJob.swift @@ -2,7 +2,8 @@ import SessionUtilitiesKit public final class MessageReceiveJob : NSObject, Job, NSCoding { // NSObject/NSCoding conformance is needed for YapDatabase compatibility public let data: Data - public let messageServerID: UInt64? + public let openGroupMessageServerID: UInt64? + public let openGroupID: String? public var delegate: JobDelegate? public var id: String? public var failureCount: UInt = 0 @@ -12,9 +13,14 @@ public final class MessageReceiveJob : NSObject, Job, NSCoding { // NSObject/NSC public static let maxFailureCount: UInt = 10 // MARK: Initialization - public init(data: Data, messageServerID: UInt64? = nil) { + public init(data: Data, openGroupMessageServerID: UInt64? = nil, openGroupID: String? = nil) { self.data = data - self.messageServerID = messageServerID + self.openGroupMessageServerID = openGroupMessageServerID + self.openGroupID = openGroupID + #if DEBUG + if openGroupMessageServerID != nil { assert(openGroupID != nil) } + if openGroupID != nil { assert(openGroupMessageServerID != nil) } + #endif } // MARK: Coding @@ -22,14 +28,16 @@ public final class MessageReceiveJob : NSObject, Job, NSCoding { // NSObject/NSC guard let data = coder.decodeObject(forKey: "data") as! Data?, let id = coder.decodeObject(forKey: "id") as! String? else { return nil } self.data = data - self.messageServerID = coder.decodeObject(forKey: "messageServerUD") as! UInt64? + self.openGroupMessageServerID = coder.decodeObject(forKey: "openGroupMessageServerID") as! UInt64? + self.openGroupID = coder.decodeObject(forKey: "openGroupID") as! String? self.id = id self.failureCount = coder.decodeObject(forKey: "failureCount") as! UInt? ?? 0 } public func encode(with coder: NSCoder) { coder.encode(data, forKey: "data") - coder.encode(messageServerID, forKey: "messageServerID") + coder.encode(openGroupMessageServerID, forKey: "openGroupMessageServerID") + coder.encode(openGroupID, forKey: "openGroupID") coder.encode(id, forKey: "id") coder.encode(failureCount, forKey: "failureCount") } @@ -38,11 +46,11 @@ public final class MessageReceiveJob : NSObject, Job, NSCoding { // NSObject/NSC public func execute() { Configuration.shared.storage.withAsync({ transaction in // Intentionally capture self do { - let (message, proto) = try MessageReceiver.parse(self.data, messageServerID: self.messageServerID, using: transaction) - try MessageReceiver.handle(message, associatedWithProto: proto, using: transaction) + let (message, proto) = try MessageReceiver.parse(self.data, openGroupMessageServerID: self.openGroupMessageServerID, using: transaction) + try MessageReceiver.handle(message, associatedWithProto: proto, openGroupID: self.openGroupID, using: transaction) self.handleSuccess() } catch { - SNLog("Couldn't parse message due to error: \(error).") + SNLog("Couldn't receive message due to error: \(error).") if let error = error as? MessageReceiver.Error, !error.isRetryable { self.handlePermanentFailure(error: error) } else { diff --git a/SessionMessagingKit/Messages/Signal/TSIncomingMessage+Conversion.swift b/SessionMessagingKit/Messages/Signal/TSIncomingMessage+Conversion.swift index c97419fcb..d48bc62ae 100644 --- a/SessionMessagingKit/Messages/Signal/TSIncomingMessage+Conversion.swift +++ b/SessionMessagingKit/Messages/Signal/TSIncomingMessage+Conversion.swift @@ -7,6 +7,8 @@ public extension TSIncomingMessage { Storage.read { transaction in expiration = thread.disappearingMessagesDuration(with: transaction) } + let openGroupServerMessageID = visibleMessage.openGroupServerMessageID ?? 0 + let isOpenGroupMessage = (openGroupServerMessageID != 0) let result = TSIncomingMessage( timestamp: visibleMessage.sentTimestamp!, in: thread, @@ -14,14 +16,14 @@ public extension TSIncomingMessage { sourceDeviceId: 1, messageBody: visibleMessage.text, attachmentIds: visibleMessage.attachmentIDs, - expiresInSeconds: expiration, + expiresInSeconds: !isOpenGroupMessage ? expiration : 0, // Ensure we don't ever expire open group messages quotedMessage: quotedMessage, linkPreview: linkPreview, serverTimestamp: nil, wasReceivedByUD: true ) - result.openGroupServerMessageID = visibleMessage.openGroupServerMessageID ?? 0 - result.isOpenGroupMessage = result.openGroupServerMessageID != 0 + result.openGroupServerMessageID = openGroupServerMessageID + result.isOpenGroupMessage = isOpenGroupMessage return result } } diff --git a/SessionMessagingKit/Sending & Receiving/MessageReceiver+Handling.swift b/SessionMessagingKit/Sending & Receiving/MessageReceiver+Handling.swift index a1fb885c0..0b3b2d8cd 100644 --- a/SessionMessagingKit/Sending & Receiving/MessageReceiver+Handling.swift +++ b/SessionMessagingKit/Sending & Receiving/MessageReceiver+Handling.swift @@ -7,13 +7,13 @@ extension MessageReceiver { return SSKEnvironment.shared.blockingManager.isRecipientIdBlocked(publicKey) } - internal static func handle(_ message: Message, associatedWithProto proto: SNProtoContent, using transaction: Any) throws { + internal static func handle(_ message: Message, associatedWithProto proto: SNProtoContent, openGroupID: String?, using transaction: Any) throws { switch message { case let message as ReadReceipt: handleReadReceipt(message, using: transaction) case let message as TypingIndicator: handleTypingIndicator(message, using: transaction) case let message as ClosedGroupUpdate: handleClosedGroupUpdate(message, using: transaction) case let message as ExpirationTimerUpdate: handleExpirationTimerUpdate(message, using: transaction) - case let message as VisibleMessage: try handleVisibleMessage(message, associatedWithProto: proto, using: transaction) + case let message as VisibleMessage: try handleVisibleMessage(message, associatedWithProto: proto, openGroupID: openGroupID, using: transaction) default: fatalError() } } @@ -135,7 +135,7 @@ extension MessageReceiver { SSKEnvironment.shared.disappearingMessagesJob.startIfNecessary() } - private static func handleVisibleMessage(_ message: VisibleMessage, associatedWithProto proto: SNProtoContent, using transaction: Any) throws { + private static func handleVisibleMessage(_ message: VisibleMessage, associatedWithProto proto: SNProtoContent, openGroupID: String?, using transaction: Any) throws { let storage = Configuration.shared.storage let transaction = transaction as! YapDatabaseReadWriteTransaction // Parse & persist attachments @@ -159,7 +159,7 @@ extension MessageReceiver { } } // Get or create thread - guard let threadID = storage.getOrCreateThread(for: message.sender!, groupPublicKey: message.groupPublicKey, using: transaction) else { throw Error.noThread } + guard let threadID = storage.getOrCreateThread(for: message.sender!, groupPublicKey: message.groupPublicKey, openGroupID: openGroupID, using: transaction) else { throw Error.noThread } // Parse quote if needed var tsQuotedMessage: TSQuotedMessage? = nil if message.quote != nil && proto.dataMessage?.quote != nil, let thread = TSThread.fetch(uniqueId: threadID, transaction: transaction) { @@ -178,7 +178,7 @@ extension MessageReceiver { } // Persist the message guard let tsIncomingMessageID = storage.persist(message, quotedMessage: tsQuotedMessage, linkPreview: owsLinkPreview, - groupPublicKey: message.groupPublicKey, using: transaction) else { throw Error.noThread } + groupPublicKey: message.groupPublicKey, openGroupID: openGroupID, using: transaction) else { throw Error.noThread } message.threadID = threadID // Start attachment downloads if needed storage.withAsync({ transaction in diff --git a/SessionMessagingKit/Sending & Receiving/MessageReceiver.swift b/SessionMessagingKit/Sending & Receiving/MessageReceiver.swift index a42a60bae..3cd16bf29 100644 --- a/SessionMessagingKit/Sending & Receiving/MessageReceiver.swift +++ b/SessionMessagingKit/Sending & Receiving/MessageReceiver.swift @@ -43,8 +43,9 @@ internal enum MessageReceiver { } } - internal static func parse(_ data: Data, messageServerID: UInt64?, using transaction: Any) throws -> (Message, SNProtoContent) { + internal static func parse(_ data: Data, openGroupMessageServerID: UInt64?, using transaction: Any) throws -> (Message, SNProtoContent) { let userPublicKey = Configuration.shared.storage.getUserPublicKey() + let isOpenGroupMessage = (openGroupMessageServerID != 0) // Parse the envelope let envelope = try SNProtoEnvelope.parseData(data) let storage = Configuration.shared.storage @@ -54,12 +55,16 @@ internal enum MessageReceiver { let plaintext: Data let sender: String var groupPublicKey: String? = nil - switch envelope.type { - case .unidentifiedSender: (plaintext, sender) = try decryptWithSignalProtocol(envelope: envelope, using: transaction) - case .closedGroupCiphertext: - (plaintext, sender) = try decryptWithSharedSenderKeys(envelope: envelope, using: transaction) - groupPublicKey = envelope.source - default: throw Error.unknownEnvelopeType + if isOpenGroupMessage { + (plaintext, sender) = (envelope.content!, envelope.source!) + } else { + switch envelope.type { + case .unidentifiedSender: (plaintext, sender) = try decryptWithSignalProtocol(envelope: envelope, using: transaction) + case .closedGroupCiphertext: + (plaintext, sender) = try decryptWithSharedSenderKeys(envelope: envelope, using: transaction) + groupPublicKey = envelope.source + default: throw Error.unknownEnvelopeType + } } // Don't process the envelope any further if the sender is blocked guard !isBlocked(sender) else { throw Error.senderBlocked } @@ -83,12 +88,15 @@ internal enum MessageReceiver { return nil }() if let message = message { + if isOpenGroupMessage { + guard message is VisibleMessage else { throw Error.invalidMessage } + } message.sender = sender message.recipient = userPublicKey message.sentTimestamp = envelope.timestamp message.receivedTimestamp = NSDate.millisecondTimestamp() message.groupPublicKey = groupPublicKey - message.openGroupServerMessageID = messageServerID + message.openGroupServerMessageID = openGroupMessageServerID var isValid = message.isValid if message is VisibleMessage && !isValid && proto.dataMessage?.attachments.isEmpty == false { isValid = true diff --git a/SessionMessagingKit/Sending & Receiving/MessageSender.swift b/SessionMessagingKit/Sending & Receiving/MessageSender.swift index 03f1f83ec..caad66df1 100644 --- a/SessionMessagingKit/Sending & Receiving/MessageSender.swift +++ b/SessionMessagingKit/Sending & Receiving/MessageSender.swift @@ -203,6 +203,7 @@ public final class MessageSender : NSObject { let (promise, seal) = Promise.pending() let storage = Configuration.shared.storage message.sentTimestamp = NSDate.millisecondTimestamp() + message.sender = storage.getUserPublicKey() switch destination { case .contact(_): preconditionFailure() case .closedGroup(_): preconditionFailure() @@ -215,6 +216,12 @@ public final class MessageSender : NSObject { }, completion: { }) } // Validate the message + if !(message is VisibleMessage) { + #if DEBUG + preconditionFailure() + #endif + seal.reject(Error.invalidMessage); return promise + } guard message.isValid else { seal.reject(Error.invalidMessage); return promise } // Convert the message to an open group message let (channel, server) = { () -> (UInt64, String) in diff --git a/SessionMessagingKit/Sending & Receiving/Pollers/OpenGroupPoller.swift b/SessionMessagingKit/Sending & Receiving/Pollers/OpenGroupPoller.swift index 5dccd17da..bb2994013 100644 --- a/SessionMessagingKit/Sending & Receiving/Pollers/OpenGroupPoller.swift +++ b/SessionMessagingKit/Sending & Receiving/Pollers/OpenGroupPoller.swift @@ -142,7 +142,7 @@ public final class OpenGroupPoller : NSObject { syncMessageBuilder.setSent(syncMessageSent) content.setSyncMessage(try! syncMessageBuilder.build()) } - let envelope = SNProtoEnvelope.builder(type: .ciphertext, timestamp: message.timestamp) + let envelope = SNProtoEnvelope.builder(type: .unidentifiedSender, timestamp: message.timestamp) envelope.setSource(senderPublicKey) envelope.setSourceDevice(1) envelope.setContent(try! content.build().serializedData()) @@ -150,7 +150,7 @@ public final class OpenGroupPoller : NSObject { Storage.write { transaction in transaction.setObject(senderDisplayName, forKey: senderPublicKey, inCollection: openGroup.id) let messageServerID = message.serverID - let job = MessageReceiveJob(data: try! envelope.buildSerializedData(), messageServerID: messageServerID) + let job = MessageReceiveJob(data: try! envelope.buildSerializedData(), openGroupMessageServerID: messageServerID, openGroupID: openGroup.id) Storage.write { transaction in SessionMessagingKit.JobQueue.shared.add(job, using: transaction) } diff --git a/SessionMessagingKit/Storage.swift b/SessionMessagingKit/Storage.swift index 1637cc74a..650f2acf7 100644 --- a/SessionMessagingKit/Storage.swift +++ b/SessionMessagingKit/Storage.swift @@ -45,6 +45,7 @@ public protocol SessionMessagingKitStorageProtocol { // MARK: - Open Groups func getOpenGroup(for threadID: String) -> OpenGroup? + func getThreadID(for openGroupID: String) -> String? // MARK: - Open Group Public Keys @@ -74,9 +75,9 @@ public protocol SessionMessagingKitStorageProtocol { func getReceivedMessageTimestamps(using transaction: Any) -> [UInt64] func addReceivedMessageTimestamp(_ timestamp: UInt64, using transaction: Any) /// Returns the ID of the thread. - func getOrCreateThread(for publicKey: String, groupPublicKey: String?, using transaction: Any) -> String? + func getOrCreateThread(for publicKey: String, groupPublicKey: String?, openGroupID: String?, using transaction: Any) -> String? /// Returns the ID of the `TSIncomingMessage` that was constructed. - func persist(_ message: VisibleMessage, quotedMessage: TSQuotedMessage?, linkPreview: OWSLinkPreview?, groupPublicKey: String?, using transaction: Any) -> String? + func persist(_ message: VisibleMessage, quotedMessage: TSQuotedMessage?, linkPreview: OWSLinkPreview?, groupPublicKey: String?, openGroupID: String?, using transaction: Any) -> String? /// Returns the IDs of the saved attachments. func persist(_ attachments: [VisibleMessage.Attachment], using transaction: Any) -> [String] /// Also touches the associated message.