WIP adding PNs for updated groups and changes for enabling unit testing

Started adding logic to subscribe and unsubscribe for updated group push notifications
Updated the keychain service to be injected via dependencies
Reworked the Dependencies logic to avoid a concurrent access issue
Fixed an issue where some keychain data might not be cleared in some cases
Fixed an issue where being kicked and readded to a group would make it seem to disappear ("invited" message wasn't getting created)
pull/941/head
Morgan Pretty 2 years ago
parent 0c1ea23b08
commit 32f5a18e00

@ -278,7 +278,7 @@
C32C5A88256DBCF9003C73A2 /* MessageReceiver+LegacyClosedGroups.swift in Sources */ = {isa = PBXBuildFile; fileRef = C32C5A87256DBCF9003C73A2 /* MessageReceiver+LegacyClosedGroups.swift */; }; C32C5A88256DBCF9003C73A2 /* MessageReceiver+LegacyClosedGroups.swift in Sources */ = {isa = PBXBuildFile; fileRef = C32C5A87256DBCF9003C73A2 /* MessageReceiver+LegacyClosedGroups.swift */; };
C32C5C3D256DCBAF003C73A2 /* AppReadiness.m in Sources */ = {isa = PBXBuildFile; fileRef = C33FDB75255A581000E217F9 /* AppReadiness.m */; }; C32C5C3D256DCBAF003C73A2 /* AppReadiness.m in Sources */ = {isa = PBXBuildFile; fileRef = C33FDB75255A581000E217F9 /* AppReadiness.m */; };
C32C5C46256DCBB2003C73A2 /* AppReadiness.h in Headers */ = {isa = PBXBuildFile; fileRef = C33FDB01255A580700E217F9 /* AppReadiness.h */; settings = {ATTRIBUTES = (Public, ); }; }; C32C5C46256DCBB2003C73A2 /* AppReadiness.h in Headers */ = {isa = PBXBuildFile; fileRef = C33FDB01255A580700E217F9 /* AppReadiness.h */; settings = {ATTRIBUTES = (Public, ); }; };
C32C5D83256DD5B6003C73A2 /* SSKKeychainStorage.swift in Sources */ = {isa = PBXBuildFile; fileRef = C33FDBBC255A581600E217F9 /* SSKKeychainStorage.swift */; }; C32C5D83256DD5B6003C73A2 /* KeychainStorageType.swift in Sources */ = {isa = PBXBuildFile; fileRef = C33FDBBC255A581600E217F9 /* KeychainStorageType.swift */; };
C32C5DBF256DD743003C73A2 /* GroupPoller.swift in Sources */ = {isa = PBXBuildFile; fileRef = C33FDB34255A580B00E217F9 /* GroupPoller.swift */; }; C32C5DBF256DD743003C73A2 /* GroupPoller.swift in Sources */ = {isa = PBXBuildFile; fileRef = C33FDB34255A580B00E217F9 /* GroupPoller.swift */; };
C32C5DC9256DD935003C73A2 /* ProxiedContentDownloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = C33FDAF2255A580500E217F9 /* ProxiedContentDownloader.swift */; }; C32C5DC9256DD935003C73A2 /* ProxiedContentDownloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = C33FDAF2255A580500E217F9 /* ProxiedContentDownloader.swift */; };
C32C5DD2256DD9E5003C73A2 /* LRUCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = C33FDAFD255A580600E217F9 /* LRUCache.swift */; }; C32C5DD2256DD9E5003C73A2 /* LRUCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = C33FDAFD255A580600E217F9 /* LRUCache.swift */; };
@ -1548,7 +1548,7 @@
C33FDBA8255A581500E217F9 /* LinkPreviewDraft.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LinkPreviewDraft.swift; sourceTree = "<group>"; }; C33FDBA8255A581500E217F9 /* LinkPreviewDraft.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LinkPreviewDraft.swift; sourceTree = "<group>"; };
C33FDBAB255A581500E217F9 /* OWSFileSystem.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = OWSFileSystem.h; sourceTree = "<group>"; }; C33FDBAB255A581500E217F9 /* OWSFileSystem.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = OWSFileSystem.h; sourceTree = "<group>"; };
C33FDBB6255A581600E217F9 /* DataSource.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = DataSource.m; sourceTree = "<group>"; }; C33FDBB6255A581600E217F9 /* DataSource.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = DataSource.m; sourceTree = "<group>"; };
C33FDBBC255A581600E217F9 /* SSKKeychainStorage.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SSKKeychainStorage.swift; sourceTree = "<group>"; }; C33FDBBC255A581600E217F9 /* KeychainStorageType.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = KeychainStorageType.swift; sourceTree = "<group>"; };
C33FDBD3255A581800E217F9 /* OWSSignalAddress.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OWSSignalAddress.swift; sourceTree = "<group>"; }; C33FDBD3255A581800E217F9 /* OWSSignalAddress.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OWSSignalAddress.swift; sourceTree = "<group>"; };
C33FDBDE255A581900E217F9 /* PushNotificationAPI.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PushNotificationAPI.swift; sourceTree = "<group>"; }; C33FDBDE255A581900E217F9 /* PushNotificationAPI.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PushNotificationAPI.swift; sourceTree = "<group>"; };
C33FDC1B255A581F00E217F9 /* OWSBackgroundTask.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = OWSBackgroundTask.m; sourceTree = "<group>"; }; C33FDC1B255A581F00E217F9 /* OWSBackgroundTask.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = OWSBackgroundTask.m; sourceTree = "<group>"; };
@ -2747,7 +2747,6 @@
FD28A4F527EAD44C00FF65E7 /* Storage.swift */, FD28A4F527EAD44C00FF65E7 /* Storage.swift */,
C33FDBAB255A581500E217F9 /* OWSFileSystem.h */, C33FDBAB255A581500E217F9 /* OWSFileSystem.h */,
C33FDA8E255A57FD00E217F9 /* OWSFileSystem.m */, C33FDA8E255A57FD00E217F9 /* OWSFileSystem.m */,
C33FDBBC255A581600E217F9 /* SSKKeychainStorage.swift */,
); );
path = Database; path = Database;
sourceTree = "<group>"; sourceTree = "<group>";
@ -3790,6 +3789,7 @@
FDFBB74A2A1EFF4900CA7350 /* Bencode.swift */, FDFBB74A2A1EFF4900CA7350 /* Bencode.swift */,
FD97B23F2A3FEB050027DD57 /* ARC4RandomNumberGenerator.swift */, FD97B23F2A3FEB050027DD57 /* ARC4RandomNumberGenerator.swift */,
FDA8EB0F280F8238002B68E5 /* Codable+Utilities.swift */, FDA8EB0F280F8238002B68E5 /* Codable+Utilities.swift */,
C33FDBBC255A581600E217F9 /* KeychainStorageType.swift */,
FD3003692A3ADD6000B5A5FB /* CExceptionHelper.h */, FD3003692A3ADD6000B5A5FB /* CExceptionHelper.h */,
FD30036D2A3AE26000B5A5FB /* CExceptionHelper.mm */, FD30036D2A3AE26000B5A5FB /* CExceptionHelper.mm */,
FD12A84A2AD6458800EEBA0D /* DifferenceKit+Utilities.swift */, FD12A84A2AD6458800EEBA0D /* DifferenceKit+Utilities.swift */,
@ -6144,7 +6144,7 @@
FDE658A129418C7900A33BC1 /* CryptoKit+Utilities.swift in Sources */, FDE658A129418C7900A33BC1 /* CryptoKit+Utilities.swift in Sources */,
FDFD645927F26C6800808CA1 /* Array+Utilities.swift in Sources */, FDFD645927F26C6800808CA1 /* Array+Utilities.swift in Sources */,
7B1D74B027C365960030B423 /* Timer+MainThread.swift in Sources */, 7B1D74B027C365960030B423 /* Timer+MainThread.swift in Sources */,
C32C5D83256DD5B6003C73A2 /* SSKKeychainStorage.swift in Sources */, C32C5D83256DD5B6003C73A2 /* KeychainStorageType.swift in Sources */,
FD559DF52A7368CB00C7C62A /* DispatchQueue+Utilities.swift in Sources */, FD559DF52A7368CB00C7C62A /* DispatchQueue+Utilities.swift in Sources */,
FDF8488329405A12007DCAE5 /* BatchResponse.swift in Sources */, FDF8488329405A12007DCAE5 /* BatchResponse.swift in Sources */,
C3D9E39B256763C20040E4F3 /* AppContext.m in Sources */, C3D9E39B256763C20040E4F3 /* AppContext.m in Sources */,

@ -121,10 +121,11 @@ public struct SessionApp {
DDLog.flushLog() DDLog.flushLog()
SessionUtil.clearMemoryState(using: dependencies) SessionUtil.clearMemoryState(using: dependencies)
Storage.resetAllStorage() Storage.resetAllStorage(using: dependencies)
DisplayPictureManager.resetStorage(using: dependencies) DisplayPictureManager.resetStorage(using: dependencies)
Attachment.resetAttachmentStorage() Attachment.resetAttachmentStorage()
AppEnvironment.shared.notificationPresenter.clearAllNotifications() AppEnvironment.shared.notificationPresenter.clearAllNotifications()
dependencies[singleton: .keychain].removeAll()
onReset?() onReset?()
exit(0) exit(0)

@ -123,6 +123,9 @@ enum Onboarding {
) { ) {
let sessionId: SessionId = SessionId(.standard, publicKey: x25519KeyPair.publicKey) let sessionId: SessionId = SessionId(.standard, publicKey: x25519KeyPair.publicKey)
// Reset the PushNotificationAPI keys (just in case they were left over from a prior install)
PushNotificationAPI.deleteKeys(using: dependencies)
// Store the user identity information // Store the user identity information
dependencies[singleton: .storage].write { db in dependencies[singleton: .storage].write { db in
try Identity.store( try Identity.store(
@ -185,7 +188,7 @@ enum Onboarding {
// Only continue if this isn't a new account // Only continue if this isn't a new account
guard self != .register else { return } guard self != .register else { return }
// Fetch the // Fetch the profile name
Onboarding.profileNamePublisher Onboarding.profileNamePublisher
.subscribe(on: DispatchQueue.global(qos: .userInitiated), using: dependencies) .subscribe(on: DispatchQueue.global(qos: .userInitiated), using: dependencies)
.sinkUntilComplete() .sinkUntilComplete()

@ -225,6 +225,19 @@ public extension ClosedGroup {
/// Start polling /// Start polling
dependencies[singleton: .groupsPoller].startIfNeeded(for: group.id, using: dependencies) dependencies[singleton: .groupsPoller].startIfNeeded(for: group.id, using: dependencies)
/// Subscribe for group push notifications
if let token: String = dependencies[defaults: .standard, key: .deviceToken] {
try? PushNotificationAPI
.preparedSubscribe(
db,
token: Data(hex: token),
sessionId: SessionId(.group, hex: group.id),
using: dependencies
)
.send(using: dependencies)
.sinkUntilComplete()
}
} }
static func removeData( static func removeData(
@ -271,13 +284,24 @@ public extension ClosedGroup {
try? PushNotificationAPI try? PushNotificationAPI
.preparedUnsubscribeFromLegacyGroup( .preparedUnsubscribeFromLegacyGroup(
legacyGroupId: threadId, legacyGroupId: threadId,
userSessionId: userSessionId userSessionId: userSessionId,
using: dependencies
) )
.send(using: dependencies) .send(using: dependencies)
.sinkUntilComplete() .sinkUntilComplete()
case .group: case .group:
break if let token: String = dependencies[defaults: .standard, key: .deviceToken] {
try? PushNotificationAPI
.preparedUnsubscribe(
db,
token: Data(hex: token),
sessionId: userSessionId,
using: dependencies
)
.send(using: dependencies)
.sinkUntilComplete()
}
default: break default: break
} }

@ -132,6 +132,10 @@ extension MessageReceiver {
/// or modified clients /// or modified clients
let inviteSenderIsApproved: Bool = ((try? Contact.fetchOne(db, id: sender))?.isApproved == true) let inviteSenderIsApproved: Bool = ((try? Contact.fetchOne(db, id: sender))?.isApproved == true)
let threadAlreadyExisted: Bool = ((try? SessionThread.exists(db, id: message.groupSessionId.hexString)) ?? false) let threadAlreadyExisted: Bool = ((try? SessionThread.exists(db, id: message.groupSessionId.hexString)) ?? false)
let wasKickedFromGroup: Bool = SessionUtil.wasKickedFromGroup(
groupSessionId: message.groupSessionId,
using: dependencies
)
try MessageReceiver.handleNewGroup( try MessageReceiver.handleNewGroup(
db, db,
@ -155,7 +159,7 @@ extension MessageReceiver {
).upsert(db) ).upsert(db)
/// If the thread didn't already exist then insert an 'invited' info message /// If the thread didn't already exist then insert an 'invited' info message
guard !threadAlreadyExisted else { return } guard !threadAlreadyExisted || wasKickedFromGroup else { return }
let interaction: Interaction = try Interaction( let interaction: Interaction = try Interaction(
threadId: message.groupSessionId.hexString, threadId: message.groupSessionId.hexString,

@ -263,7 +263,8 @@ extension MessageReceiver {
) )
.asRequest(of: String.self) .asRequest(of: String.self)
.fetchSet(db) .fetchSet(db)
.inserting(legacyGroupSessionId) // Insert the new key just to be sure .inserting(legacyGroupSessionId), // Insert the new key just to be sure
using: dependencies
)? )?
.send(using: dependencies) .send(using: dependencies)
.subscribe(on: DispatchQueue.global(qos: .default), using: dependencies) .subscribe(on: DispatchQueue.global(qos: .default), using: dependencies)

@ -75,12 +75,17 @@ extension MessageSender {
try createdInfo.members.forEach { try $0.insert(db) } try createdInfo.members.forEach { try $0.insert(db) }
// Prepare the notification subscription // Prepare the notification subscription
let preparedNotificationSubscription = try? PushNotificationAPI var preparedNotificationSubscription: HTTP.PreparedRequest<PushNotificationAPI.SubscribeResponse>?
.preparedSubscribe(
db, if let token: String = dependencies[defaults: .standard, key: .deviceToken] {
sessionId: createdInfo.groupSessionId, preparedNotificationSubscription = try? PushNotificationAPI
using: dependencies .preparedSubscribe(
) db,
token: Data(hex: token),
sessionId: createdInfo.groupSessionId,
using: dependencies
)
}
return ( return (
createdInfo.groupSessionId, createdInfo.groupSessionId,

@ -144,7 +144,8 @@ extension MessageSender {
try? PushNotificationAPI try? PushNotificationAPI
.preparedSubscribeToLegacyGroups( .preparedSubscribeToLegacyGroups(
userSessionId: userSessionId, userSessionId: userSessionId,
legacyGroupIds: allActiveLegacyGroupIds legacyGroupIds: allActiveLegacyGroupIds,
using: dependencies
)? )?
.map { _, _ in () } .map { _, _ in () }
) )

@ -1,4 +1,6 @@
// Copyright © 2022 Rangeproof Pty Ltd. All rights reserved. // Copyright © 2022 Rangeproof Pty Ltd. All rights reserved.
//
// stringlint:disable
import Foundation import Foundation
import Combine import Combine
@ -6,9 +8,14 @@ import GRDB
import SessionSnodeKit import SessionSnodeKit
import SessionUtilitiesKit import SessionUtilitiesKit
// MARK: - KeychainStorage
public extension KeychainStorage.ServiceKey { static let pushNotificationAPI: Self = "PNKeyChainService" }
public extension KeychainStorage.DataKey { static let pushNotificationEncryptionKey: Self = "PNEncryptionKeyKey" }
// MARK: - PushNotificationAPI
public enum PushNotificationAPI { public enum PushNotificationAPI {
private static let keychainService: String = "PNKeyChainService"
private static let encryptionKeyKey: String = "PNEncryptionKeyKey"
private static let encryptionKeyLength: Int = 32 private static let encryptionKeyLength: Int = 32
private static let maxRetryCount: Int = 4 private static let maxRetryCount: Int = 4
private static let tokenExpirationInterval: TimeInterval = (12 * 60 * 60) private static let tokenExpirationInterval: TimeInterval = (12 * 60 * 60)
@ -23,10 +30,11 @@ public enum PushNotificationAPI {
public static func subscribeAll( public static func subscribeAll(
token: Data, token: Data,
isForcedUpdate: Bool, isForcedUpdate: Bool,
using dependencies: Dependencies = Dependencies() using dependencies: Dependencies
) -> AnyPublisher<Void, Error> { ) -> AnyPublisher<Void, Error> {
typealias SubscribeAllPreparedRequests = ( typealias SubscribeAllPreparedRequests = (
HTTP.PreparedRequest<PushNotificationAPI.SubscribeResponse>, HTTP.PreparedRequest<PushNotificationAPI.SubscribeResponse>,
[HTTP.PreparedRequest<PushNotificationAPI.SubscribeResponse>],
HTTP.PreparedRequest<PushNotificationAPI.LegacyPushServerResponse>? HTTP.PreparedRequest<PushNotificationAPI.LegacyPushServerResponse>?
) )
let hexEncodedToken: String = token.toHexString() let hexEncodedToken: String = token.toHexString()
@ -47,6 +55,7 @@ public enum PushNotificationAPI {
let preparedUserRequest = try PushNotificationAPI let preparedUserRequest = try PushNotificationAPI
.preparedSubscribe( .preparedSubscribe(
db, db,
token: token,
sessionId: userSessionId, sessionId: userSessionId,
using: dependencies using: dependencies
) )
@ -59,6 +68,21 @@ public enum PushNotificationAPI {
dependencies[defaults: .standard, key: .isUsingFullAPNs] = true dependencies[defaults: .standard, key: .isUsingFullAPNs] = true
} }
) )
let preparedGroupRequests = try ClosedGroup
.select(.threadId)
.filter(ClosedGroup.Columns.threadId.like("\(SessionId.Prefix.group.rawValue)%"))
.filter(ClosedGroup.Columns.shouldPoll)
.asRequest(of: String.self)
.fetchSet(db)
.map { groupId in
try PushNotificationAPI
.preparedSubscribe(
db,
token: token,
sessionId: SessionId(.group, hex: groupId),
using: dependencies
)
}
let preparedLegacyGroupRequest = try PushNotificationAPI let preparedLegacyGroupRequest = try PushNotificationAPI
.preparedSubscribeToLegacyGroups( .preparedSubscribeToLegacyGroups(
forced: true, forced: true,
@ -78,10 +102,11 @@ public enum PushNotificationAPI {
return ( return (
preparedUserRequest, preparedUserRequest,
preparedGroupRequests,
preparedLegacyGroupRequest preparedLegacyGroupRequest
) )
} }
.flatMap { userRequest, legacyGroupRequest -> AnyPublisher<Void, Error> in .flatMap { userRequest, preparedGroupRequests, legacyGroupRequest -> AnyPublisher<Void, Error> in
Publishers Publishers
.MergeMany( .MergeMany(
[ [
@ -94,7 +119,16 @@ public enum PushNotificationAPI {
.send(using: dependencies) .send(using: dependencies)
.map { _, _ in () } .map { _, _ in () }
.eraseToAnyPublisher() .eraseToAnyPublisher()
].compactMap { $0 } ]
.appending(
contentsOf: preparedGroupRequests.map { request in
request
.send(using: dependencies)
.map { _, _ in () }
.eraseToAnyPublisher()
}
)
.compactMap { $0 }
) )
.collect() .collect()
.map { _ in () } .map { _ in () }
@ -105,27 +139,23 @@ public enum PushNotificationAPI {
public static func unsubscribeAll( public static func unsubscribeAll(
token: Data, token: Data,
using dependencies: Dependencies = Dependencies() using dependencies: Dependencies
) -> AnyPublisher<Void, Error> { ) -> AnyPublisher<Void, Error> {
typealias UnsubscribeAllPreparedRequests = ( typealias UnsubscribeAllPreparedRequests = (
HTTP.PreparedRequest<PushNotificationAPI.UnsubscribeResponse>, HTTP.PreparedRequest<PushNotificationAPI.UnsubscribeResponse>,
[HTTP.PreparedRequest<PushNotificationAPI.UnsubscribeResponse>],
[HTTP.PreparedRequest<PushNotificationAPI.LegacyPushServerResponse>] [HTTP.PreparedRequest<PushNotificationAPI.LegacyPushServerResponse>]
) )
return dependencies[singleton: .storage] return dependencies[singleton: .storage]
.readPublisher(using: dependencies) { db -> UnsubscribeAllPreparedRequests in .readPublisher(using: dependencies) { db -> UnsubscribeAllPreparedRequests in
guard let userED25519KeyPair: KeyPair = Identity.fetchUserEd25519KeyPair(db) else {
throw SnodeAPIError.noKeyPair
}
let userSessionId: SessionId = getUserSessionId(db, using: dependencies) let userSessionId: SessionId = getUserSessionId(db, using: dependencies)
let preparedUserRequest = try PushNotificationAPI let preparedUserRequest = try PushNotificationAPI
.preparedUnsubscribe( .preparedUnsubscribe(
db, db,
token: token, token: token,
sessionId: userSessionId, sessionId: userSessionId,
subkey: nil, using: dependencies
ed25519KeyPair: userED25519KeyPair
) )
.handleEvents( .handleEvents(
receiveOutput: { _, response in receiveOutput: { _, response in
@ -134,6 +164,20 @@ public enum PushNotificationAPI {
dependencies[defaults: .standard, key: .deviceToken] = nil dependencies[defaults: .standard, key: .deviceToken] = nil
} }
) )
let preparedGroupUnsubscribeRequests = (try? ClosedGroup
.select(.threadId)
.filter(ClosedGroup.Columns.threadId.like("\(SessionId.Prefix.group.rawValue)%"))
.asRequest(of: String.self)
.fetchSet(db))
.defaulting(to: [])
.compactMap { groupId in
try? PushNotificationAPI.preparedUnsubscribe(
db,
token: token,
sessionId: SessionId(.group, hex: groupId),
using: dependencies
)
}
// FIXME: Remove this once legacy groups are deprecated // FIXME: Remove this once legacy groups are deprecated
let preparedLegacyUnsubscribeRequests = (try? ClosedGroup let preparedLegacyUnsubscribeRequests = (try? ClosedGroup
@ -145,13 +189,14 @@ public enum PushNotificationAPI {
.compactMap { legacyGroupId in .compactMap { legacyGroupId in
try? PushNotificationAPI.preparedUnsubscribeFromLegacyGroup( try? PushNotificationAPI.preparedUnsubscribeFromLegacyGroup(
legacyGroupId: legacyGroupId, legacyGroupId: legacyGroupId,
userSessionId: userSessionId userSessionId: userSessionId,
using: dependencies
) )
} }
return (preparedUserRequest, preparedLegacyUnsubscribeRequests) return (preparedUserRequest, preparedGroupUnsubscribeRequests, preparedLegacyUnsubscribeRequests)
} }
.flatMap { preparedUserRequest, preparedLegacyUnsubscribeRequests in .flatMap { preparedUserRequest, preparedGroupUnsubscribeRequests, preparedLegacyUnsubscribeRequests in
// FIXME: Remove this once legacy groups are deprecated // FIXME: Remove this once legacy groups are deprecated
/// Unsubscribe from all legacy groups (including ones the user is no longer a member of, just in case) /// Unsubscribe from all legacy groups (including ones the user is no longer a member of, just in case)
Publishers Publishers
@ -161,9 +206,24 @@ public enum PushNotificationAPI {
.receive(on: DispatchQueue.global(qos: .userInitiated), using: dependencies) .receive(on: DispatchQueue.global(qos: .userInitiated), using: dependencies)
.sinkUntilComplete() .sinkUntilComplete()
return preparedUserRequest.send(using: dependencies) return Publishers
.MergeMany(
[
preparedUserRequest
.send(using: dependencies)
.map { _, _ in () }
.eraseToAnyPublisher()
]
.appending(
contentsOf: preparedGroupUnsubscribeRequests.map { request in
request
.send(using: dependencies)
.map { _, _ in () }
.eraseToAnyPublisher()
}
)
)
} }
.map { _ in () }
.eraseToAnyPublisher() .eraseToAnyPublisher()
} }
@ -171,13 +231,11 @@ public enum PushNotificationAPI {
public static func preparedSubscribe( public static func preparedSubscribe(
_ db: Database, _ db: Database,
token: Data,
sessionId: SessionId, sessionId: SessionId,
using dependencies: Dependencies = Dependencies() using dependencies: Dependencies
) throws -> HTTP.PreparedRequest<SubscribeResponse> { ) throws -> HTTP.PreparedRequest<SubscribeResponse> {
guard guard dependencies[defaults: .standard, key: .isUsingFullAPNs] else { throw HTTPError.invalidRequest }
dependencies[defaults: .standard, key: .isUsingFullAPNs],
let token: String = dependencies[defaults: .standard, key: .deviceToken]
else { throw HTTPError.invalidRequest }
guard let notificationsEncryptionKey: Data = try? getOrGenerateEncryptionKey(using: dependencies) else { guard let notificationsEncryptionKey: Data = try? getOrGenerateEncryptionKey(using: dependencies) else {
SNLog("Unable to retrieve PN encryption key.") SNLog("Unable to retrieve PN encryption key.")
@ -192,7 +250,7 @@ public enum PushNotificationAPI {
body: SubscribeRequest( body: SubscribeRequest(
namespaces: { namespaces: {
switch sessionId.prefix { switch sessionId.prefix {
case .group: return [.default] case .group: return [.groupMessages]
default: return [.default, .configConvoInfoVolatile] default: return [.default, .configConvoInfoVolatile]
} }
}(), }(),
@ -201,7 +259,7 @@ public enum PushNotificationAPI {
// 'generic' notification being shown when receiving things like typing indicator updates // 'generic' notification being shown when receiving things like typing indicator updates
includeMessageData: true, includeMessageData: true,
serviceInfo: ServiceInfo( serviceInfo: ServiceInfo(
token: token token: token.toHexString()
), ),
notificationsEncryptionKey: notificationsEncryptionKey, notificationsEncryptionKey: notificationsEncryptionKey,
authMethod: try Authentication.with( authMethod: try Authentication.with(
@ -215,7 +273,8 @@ public enum PushNotificationAPI {
) )
), ),
responseType: SubscribeResponse.self, responseType: SubscribeResponse.self,
retryCount: PushNotificationAPI.maxRetryCount retryCount: PushNotificationAPI.maxRetryCount,
using: dependencies
) )
.handleEvents( .handleEvents(
receiveOutput: { _, response in receiveOutput: { _, response in
@ -236,9 +295,7 @@ public enum PushNotificationAPI {
_ db: Database, _ db: Database,
token: Data, token: Data,
sessionId: SessionId, sessionId: SessionId,
subkey: String?, using dependencies: Dependencies
ed25519KeyPair: KeyPair,
using dependencies: Dependencies = Dependencies()
) throws -> HTTP.PreparedRequest<UnsubscribeResponse> { ) throws -> HTTP.PreparedRequest<UnsubscribeResponse> {
return try PushNotificationAPI return try PushNotificationAPI
.prepareRequest( .prepareRequest(
@ -260,7 +317,8 @@ public enum PushNotificationAPI {
) )
), ),
responseType: UnsubscribeResponse.self, responseType: UnsubscribeResponse.self,
retryCount: PushNotificationAPI.maxRetryCount retryCount: PushNotificationAPI.maxRetryCount,
using: dependencies
) )
.handleEvents( .handleEvents(
receiveOutput: { _, response in receiveOutput: { _, response in
@ -284,7 +342,7 @@ public enum PushNotificationAPI {
recipient: String, recipient: String,
with message: String, with message: String,
maxRetryCount: Int? = nil, maxRetryCount: Int? = nil,
using dependencies: Dependencies = Dependencies() using dependencies: Dependencies
) throws -> HTTP.PreparedRequest<LegacyPushServerResponse> { ) throws -> HTTP.PreparedRequest<LegacyPushServerResponse> {
return try PushNotificationAPI return try PushNotificationAPI
.prepareRequest( .prepareRequest(
@ -297,7 +355,8 @@ public enum PushNotificationAPI {
) )
), ),
responseType: LegacyPushServerResponse.self, responseType: LegacyPushServerResponse.self,
retryCount: (maxRetryCount ?? PushNotificationAPI.maxRetryCount) retryCount: (maxRetryCount ?? PushNotificationAPI.maxRetryCount),
using: dependencies
) )
.handleEvents( .handleEvents(
receiveOutput: { _, response in receiveOutput: { _, response in
@ -322,7 +381,7 @@ public enum PushNotificationAPI {
token: String? = nil, token: String? = nil,
userSessionId: SessionId, userSessionId: SessionId,
legacyGroupIds: Set<String>, legacyGroupIds: Set<String>,
using dependencies: Dependencies = Dependencies() using dependencies: Dependencies
) throws -> HTTP.PreparedRequest<LegacyPushServerResponse>? { ) throws -> HTTP.PreparedRequest<LegacyPushServerResponse>? {
let isUsingFullAPNs = dependencies[defaults: .standard, key: .isUsingFullAPNs] let isUsingFullAPNs = dependencies[defaults: .standard, key: .isUsingFullAPNs]
@ -346,7 +405,8 @@ public enum PushNotificationAPI {
) )
), ),
responseType: LegacyPushServerResponse.self, responseType: LegacyPushServerResponse.self,
retryCount: PushNotificationAPI.maxRetryCount retryCount: PushNotificationAPI.maxRetryCount,
using: dependencies
) )
.handleEvents( .handleEvents(
receiveOutput: { _, response in receiveOutput: { _, response in
@ -367,7 +427,7 @@ public enum PushNotificationAPI {
public static func preparedUnsubscribeFromLegacyGroup( public static func preparedUnsubscribeFromLegacyGroup(
legacyGroupId: String, legacyGroupId: String,
userSessionId: SessionId, userSessionId: SessionId,
using dependencies: Dependencies = Dependencies() using dependencies: Dependencies
) throws -> HTTP.PreparedRequest<LegacyPushServerResponse> { ) throws -> HTTP.PreparedRequest<LegacyPushServerResponse> {
return try PushNotificationAPI return try PushNotificationAPI
.prepareRequest( .prepareRequest(
@ -380,7 +440,8 @@ public enum PushNotificationAPI {
) )
), ),
responseType: LegacyPushServerResponse.self, responseType: LegacyPushServerResponse.self,
retryCount: PushNotificationAPI.maxRetryCount retryCount: PushNotificationAPI.maxRetryCount,
using: dependencies
) )
.handleEvents( .handleEvents(
receiveOutput: { _, response in receiveOutput: { _, response in
@ -401,7 +462,7 @@ public enum PushNotificationAPI {
public static func processNotification( public static func processNotification(
notificationContent: UNNotificationContent, notificationContent: UNNotificationContent,
dependencies: Dependencies = Dependencies() using dependencies: Dependencies
) -> (data: Data?, metadata: NotificationMetadata, result: ProcessResult) { ) -> (data: Data?, metadata: NotificationMetadata, result: ProcessResult) {
// Make sure the notification is from the updated push server // Make sure the notification is from the updated push server
guard notificationContent.userInfo["spns"] != nil else { guard notificationContent.userInfo["spns"] != nil else {
@ -467,9 +528,9 @@ public enum PushNotificationAPI {
@discardableResult private static func getOrGenerateEncryptionKey(using dependencies: Dependencies) throws -> Data { @discardableResult private static func getOrGenerateEncryptionKey(using dependencies: Dependencies) throws -> Data {
do { do {
var encryptionKey: Data = try SSKDefaultKeychainStorage.shared.data( var encryptionKey: Data = try dependencies[singleton: .keychain].data(
forService: keychainService, forService: .pushNotificationAPI,
key: encryptionKeyKey key: .pushNotificationEncryptionKey
) )
defer { encryptionKey.resetBytes(in: 0..<encryptionKey.count) } defer { encryptionKey.resetBytes(in: 0..<encryptionKey.count) }
@ -486,10 +547,10 @@ public enum PushNotificationAPI {
.tryGenerate(.randomBytes(numberBytes: encryptionKeyLength))) .tryGenerate(.randomBytes(numberBytes: encryptionKeyLength)))
defer { keySpec.resetBytes(in: 0..<keySpec.count) } // Reset content immediately after use defer { keySpec.resetBytes(in: 0..<keySpec.count) } // Reset content immediately after use
try SSKDefaultKeychainStorage.shared.set( try dependencies[singleton: .keychain].set(
data: keySpec, data: keySpec,
service: keychainService, service: .pushNotificationAPI,
key: encryptionKeyKey key: .pushNotificationEncryptionKey
) )
return keySpec return keySpec
} }
@ -514,46 +575,9 @@ public enum PushNotificationAPI {
} }
} }
} }
// MARK: - Convenience
private static func send<T: Encodable>( public static func deleteKeys(using dependencies: Dependencies = Dependencies()) {
request: PushNotificationAPIRequest<T>, try? dependencies[singleton: .keychain].remove(service: .pushNotificationAPI, key: .pushNotificationEncryptionKey)
using dependencies: Dependencies
) -> AnyPublisher<(ResponseInfoType, Data?), Error> {
guard
let url: URL = URL(string: "\(request.endpoint.server)/\(request.endpoint.path)"),
let payload: Data = try? JSONEncoder(using: dependencies).encode(request.body)
else {
return Fail(error: HTTPError.invalidJSON)
.eraseToAnyPublisher()
}
guard Features.useOnionRequests else {
return HTTP
.execute(
.post,
"\(request.endpoint.server)/\(request.endpoint.path)",
body: payload
)
.map { response in (HTTP.ResponseInfo(code: -1, headers: [:]), response) }
.eraseToAnyPublisher()
}
var urlRequest: URLRequest = URLRequest(url: url)
urlRequest.httpMethod = "POST"
urlRequest.allHTTPHeaderFields = [ HTTPHeader.contentType: "application/json" ]
urlRequest.httpBody = payload
return dependencies[singleton: .network]
.send(
.onionRequest(
urlRequest,
to: request.endpoint.server,
with: request.endpoint.serverPublicKey
)
)
.eraseToAnyPublisher()
} }
// MARK: - Convenience // MARK: - Convenience
@ -563,7 +587,7 @@ public enum PushNotificationAPI {
responseType: R.Type, responseType: R.Type,
retryCount: Int = 0, retryCount: Int = 0,
timeout: TimeInterval = HTTP.defaultTimeout, timeout: TimeInterval = HTTP.defaultTimeout,
using dependencies: Dependencies = Dependencies() using dependencies: Dependencies
) throws -> HTTP.PreparedRequest<R> { ) throws -> HTTP.PreparedRequest<R> {
return HTTP.PreparedRequest<R>( return HTTP.PreparedRequest<R>(
request: request, request: request,

@ -68,7 +68,8 @@ public final class NotificationServiceExtension: UNNotificationServiceExtension
} }
let (maybeData, metadata, result) = PushNotificationAPI.processNotification( let (maybeData, metadata, result) = PushNotificationAPI.processNotification(
notificationContent: notificationContent notificationContent: notificationContent,
using: dependencies
) )
guard guard

@ -1,112 +0,0 @@
//
// Copyright (c) 2018 Open Whisper Systems. All rights reserved.
//
// stringlint:disable
import Foundation
import SAMKeychain
public enum KeychainStorageError: Error {
case failure(code: Int32?, description: String)
public var code: Int32? {
switch self {
case .failure(let code, _): return code
}
}
}
// MARK: -
@objc public protocol SSKKeychainStorage: AnyObject {
@objc func string(forService service: String, key: String) throws -> String
@objc(setString:service:key:error:) func set(string: String, service: String, key: String) throws
@objc func data(forService service: String, key: String) throws -> Data
@objc func set(data: Data, service: String, key: String) throws
@objc func remove(service: String, key: String) throws
}
// MARK: -
@objc
public class SSKDefaultKeychainStorage: NSObject, SSKKeychainStorage {
@objc public static let shared = SSKDefaultKeychainStorage()
// Force usage as a singleton
override private init() {
super.init()
}
@objc public func string(forService service: String, key: String) throws -> String {
var error: NSError?
let result = SAMKeychain.password(forService: service, account: key, error: &error)
if let error = error {
throw KeychainStorageError.failure(code: Int32(error.code), description: "\(logTag) error retrieving string: \(error)")
}
guard let string = result else {
throw KeychainStorageError.failure(code: nil, description: "\(logTag) could not retrieve string")
}
return string
}
@objc public func set(string: String, service: String, key: String) throws {
SAMKeychain.setAccessibilityType(kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly)
var error: NSError?
let result = SAMKeychain.setPassword(string, forService: service, account: key, error: &error)
if let error = error {
throw KeychainStorageError.failure(code: Int32(error.code), description: "\(logTag) error setting string: \(error)")
}
guard result else {
throw KeychainStorageError.failure(code: nil, description: "\(logTag) could not set string")
}
}
@objc public func data(forService service: String, key: String) throws -> Data {
var error: NSError?
let result = SAMKeychain.passwordData(forService: service, account: key, error: &error)
if let error = error {
throw KeychainStorageError.failure(code: Int32(error.code), description: "\(logTag) error retrieving data: \(error)")
}
guard let data = result else {
throw KeychainStorageError.failure(code: nil, description: "\(logTag) could not retrieve data")
}
return data
}
@objc public func set(data: Data, service: String, key: String) throws {
SAMKeychain.setAccessibilityType(kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly)
var error: NSError?
let result = SAMKeychain.setPasswordData(data, forService: service, account: key, error: &error)
if let error = error {
throw KeychainStorageError.failure(code: Int32(error.code), description: "\(logTag) error setting data: \(error)")
}
guard result else {
throw KeychainStorageError.failure(code: nil, description: "\(logTag) could not set data")
}
}
@objc public func remove(service: String, key: String) throws {
var error: NSError?
let result = SAMKeychain.deletePassword(forService: service, account: key, error: &error)
if let error = error {
// If deletion failed because the specified item could not be found in the keychain, consider it success.
if error.code == errSecItemNotFound {
return
}
throw KeychainStorageError.failure(code: Int32(error.code), description: "\(logTag) error removing data: \(error)")
}
guard result else {
throw KeychainStorageError.failure(code: nil, description: "\(logTag) could not remove data")
}
}
}

@ -21,13 +21,16 @@ public extension Singleton {
) )
} }
// MARK: - KeychainStorage
public extension KeychainStorage.ServiceKey { static let storage: Self = "TSKeyChainService" }
public extension KeychainStorage.DataKey { static let dbCipherKeySpec: Self = "GRDBDatabaseCipherKeySpec" }
// MARK: - Storage // MARK: - Storage
open class Storage { open class Storage {
public static let queuePrefix: String = "SessionDatabase" public static let queuePrefix: String = "SessionDatabase"
private static let dbFileName: String = "Session.sqlite" private static let dbFileName: String = "Session.sqlite"
private static let keychainService: String = "TSKeyChainService"
private static let dbCipherKeySpecKey: String = "GRDBDatabaseCipherKeySpec"
private static let kSQLCipherKeySpecLength: Int = 48 private static let kSQLCipherKeySpecLength: Int = 48
private static let writeWarningThreadshold: TimeInterval = 3 private static let writeWarningThreadshold: TimeInterval = 3
@ -394,15 +397,15 @@ open class Storage {
// MARK: - Security // MARK: - Security
private static func getDatabaseCipherKeySpec() throws -> Data { private static func getDatabaseCipherKeySpec(using dependencies: Dependencies = Dependencies()) throws -> Data {
return try SSKDefaultKeychainStorage.shared.data(forService: keychainService, key: dbCipherKeySpecKey) return try dependencies[singleton: .keychain].data(forService: .storage, key: .dbCipherKeySpec)
} }
@discardableResult private static func getOrGenerateDatabaseKeySpec( @discardableResult private static func getOrGenerateDatabaseKeySpec(
using dependencies: Dependencies = Dependencies() using dependencies: Dependencies = Dependencies()
) throws -> Data { ) throws -> Data {
do { do {
var keySpec: Data = try getDatabaseCipherKeySpec() var keySpec: Data = try getDatabaseCipherKeySpec(using: dependencies)
defer { keySpec.resetBytes(in: 0..<keySpec.count) } defer { keySpec.resetBytes(in: 0..<keySpec.count) }
guard keySpec.count == kSQLCipherKeySpecLength else { throw StorageError.invalidKeySpec } guard keySpec.count == kSQLCipherKeySpecLength else { throw StorageError.invalidKeySpec }
@ -427,7 +430,7 @@ open class Storage {
var keySpec: Data = try dependencies[singleton: .crypto].tryGenerate(.randomBytes(numberBytes: kSQLCipherKeySpecLength)) var keySpec: Data = try dependencies[singleton: .crypto].tryGenerate(.randomBytes(numberBytes: kSQLCipherKeySpecLength))
defer { keySpec.resetBytes(in: 0..<keySpec.count) } // Reset content immediately after use defer { keySpec.resetBytes(in: 0..<keySpec.count) } // Reset content immediately after use
try SSKDefaultKeychainStorage.shared.set(data: keySpec, service: keychainService, key: dbCipherKeySpecKey) try dependencies[singleton: .keychain].set(data: keySpec, service: .storage, key: .dbCipherKeySpec)
return keySpec return keySpec
} }
catch { catch {
@ -487,7 +490,7 @@ open class Storage {
Storage.internalHasCreatedValidInstance.mutate { $0 = false } Storage.internalHasCreatedValidInstance.mutate { $0 = false }
deleteDatabaseFiles() deleteDatabaseFiles()
try? deleteDbKeys() try? deleteDbKeys(using: dependencies)
} }
public static func reconfigureDatabase(using dependencies: Dependencies = Dependencies()) { public static func reconfigureDatabase(using dependencies: Dependencies = Dependencies()) {
@ -508,8 +511,8 @@ open class Storage {
OWSFileSystem.deleteFile(databasePathWal) OWSFileSystem.deleteFile(databasePathWal)
} }
private static func deleteDbKeys() throws { private static func deleteDbKeys(using dependencies: Dependencies = Dependencies()) throws {
try SSKDefaultKeychainStorage.shared.remove(service: keychainService, key: dbCipherKeySpecKey) try dependencies[singleton: .keychain].remove(service: .storage, key: .dbCipherKeySpec)
} }
// MARK: - Logging Functions // MARK: - Logging Functions

@ -15,15 +15,27 @@ public class Dependencies {
// MARK: - Subscript Access // MARK: - Subscript Access
public subscript<S>(singleton singleton: SingletonConfig<S>) -> S { public subscript<S>(singleton singleton: SingletonConfig<S>) -> S {
getValueSettingIfNull(singleton: singleton, &Dependencies.singletonInstances) guard let value: S = (Dependencies.singletonInstances.wrappedValue[singleton.identifier] as? S) else {
let value: S = singleton.createInstance(self)
Dependencies.singletonInstances.mutate { $0[singleton.identifier] = value }
return value
}
return value
} }
public subscript<M, I>(cache cache: CacheConfig<M, I>) -> I { public subscript<M, I>(cache cache: CacheConfig<M, I>) -> I {
getValueSettingIfNull(cache: cache, &Dependencies.cacheInstances) getValueSettingIfNull(cache: cache)
} }
public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType { public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType {
getValueSettingIfNull(defaults: defaults, &Dependencies.userDefaultsInstances) guard let value: UserDefaultsType = Dependencies.userDefaultsInstances.wrappedValue[defaults.identifier] else {
let value: UserDefaultsType = defaults.createInstance(self)
Dependencies.userDefaultsInstances.mutate { $0[defaults.identifier] = value }
return value
}
return value
} }
// MARK: - Timing and Async Handling // MARK: - Timing and Async Handling
@ -55,7 +67,7 @@ public class Dependencies {
/// the below code we first call `getValueSettingIfNull` to ensure we have a proper instance stored /// the below code we first call `getValueSettingIfNull` to ensure we have a proper instance stored
/// in `Dependencies.cacheInstances` so that we can be reliably certail we aren't accessing some /// in `Dependencies.cacheInstances` so that we can be reliably certail we aren't accessing some
/// random instance that will go out of memory as soon as the mutation is completed /// random instance that will go out of memory as soon as the mutation is completed
getValueSettingIfNull(cache: cache, &Dependencies.cacheInstances) getValueSettingIfNull(cache: cache)
let cacheWrapper: Atomic<MutableCacheType> = ( let cacheWrapper: Atomic<MutableCacheType> = (
Dependencies.cacheInstances.wrappedValue[cache.identifier] ?? Dependencies.cacheInstances.wrappedValue[cache.identifier] ??
@ -77,7 +89,7 @@ public class Dependencies {
/// the below code we first call `getValueSettingIfNull` to ensure we have a proper instance stored /// the below code we first call `getValueSettingIfNull` to ensure we have a proper instance stored
/// in `Dependencies.cacheInstances` so that we can be reliably certail we aren't accessing some /// in `Dependencies.cacheInstances` so that we can be reliably certail we aren't accessing some
/// random instance that will go out of memory as soon as the mutation is completed /// random instance that will go out of memory as soon as the mutation is completed
getValueSettingIfNull(cache: cache, &Dependencies.cacheInstances) getValueSettingIfNull(cache: cache)
let cacheWrapper: Atomic<MutableCacheType> = ( let cacheWrapper: Atomic<MutableCacheType> = (
Dependencies.cacheInstances.wrappedValue[cache.identifier] ?? Dependencies.cacheInstances.wrappedValue[cache.identifier] ??
@ -106,45 +118,16 @@ public class Dependencies {
// MARK: - Instance upserting // MARK: - Instance upserting
@discardableResult private func getValueSettingIfNull<S>( @discardableResult private func getValueSettingIfNull<M, I>(cache: CacheConfig<M, I>) -> I {
singleton: SingletonConfig<S>, guard let value: M = (Dependencies.cacheInstances.wrappedValue[cache.identifier]?.wrappedValue as? M) else {
_ store: inout Atomic<[String: Any]>
) -> S {
guard let value: S = (store.wrappedValue[singleton.identifier] as? S) else {
let value: S = singleton.createInstance(self)
store.mutate { $0[singleton.identifier] = value }
return value
}
return value
}
@discardableResult private func getValueSettingIfNull<M, I>(
cache: CacheConfig<M, I>,
_ store: inout Atomic<[String: Atomic<MutableCacheType>]>
) -> I {
guard let value: M = (store.wrappedValue[cache.identifier]?.wrappedValue as? M) else {
let value: M = cache.createInstance(self) let value: M = cache.createInstance(self)
let mutableInstance: MutableCacheType = cache.mutableInstance(value) let mutableInstance: MutableCacheType = cache.mutableInstance(value)
store.mutate { $0[cache.identifier] = Atomic(mutableInstance) } Dependencies.cacheInstances.mutate { $0[cache.identifier] = Atomic(mutableInstance) }
return cache.immutableInstance(value) return cache.immutableInstance(value)
} }
return cache.immutableInstance(value) return cache.immutableInstance(value)
} }
@discardableResult private func getValueSettingIfNull(
defaults: UserDefaultsConfig,
_ store: inout Atomic<[String: UserDefaultsType]>
) -> UserDefaultsType {
guard let value: UserDefaultsType = store.wrappedValue[defaults.identifier] else {
let value: UserDefaultsType = defaults.createInstance(self)
store.mutate { $0[defaults.identifier] = value }
return value
}
return value
}
} }
// MARK: - Storage Setting Convenience // MARK: - Storage Setting Convenience

@ -0,0 +1,195 @@
// Copyright © 2023 Rangeproof Pty Ltd. All rights reserved.
//
// stringlint:disable
import Foundation
import SAMKeychain
// MARK: - Singleton
public extension Singleton {
static let keychain: SingletonConfig<KeychainStorageType> = Dependencies.create(
identifier: "keychain",
createInstance: { _ in KeychainStorage() }
)
}
public enum KeychainStorageError: Error {
case failure(code: Int32?, description: String)
public var code: Int32? {
switch self {
case .failure(let code, _): return code
}
}
}
// MARK: - KeychainStorageType
public protocol KeychainStorageType {
func string(forService service: KeychainStorage.ServiceKey, key: KeychainStorage.StringKey) throws -> String
func set(string: String, service: KeychainStorage.ServiceKey, key: KeychainStorage.StringKey) throws
func remove(service: KeychainStorage.ServiceKey, key: KeychainStorage.StringKey) throws
func data(forService service: KeychainStorage.ServiceKey, key: KeychainStorage.DataKey) throws -> Data
func set(data: Data, service: KeychainStorage.ServiceKey, key: KeychainStorage.DataKey) throws
func remove(service: KeychainStorage.ServiceKey, key: KeychainStorage.DataKey) throws
func removeAll()
}
// MARK: - KeychainStorage
public class KeychainStorage: KeychainStorageType {
public func string(forService service: KeychainStorage.ServiceKey, key: KeychainStorage.StringKey) throws -> String {
var error: NSError?
let result: String? = SAMKeychain.password(forService: service.rawValue, account: key.rawValue, error: &error)
switch (error, result) {
case (.some(let error), _):
throw KeychainStorageError.failure(
code: Int32(error.code),
description: "[KeychainStorage] Error retrieving string: \(error)"
)
case (_, .none):
throw KeychainStorageError.failure(code: nil, description: "[KeychainStorage] Could not retrieve string")
case (_, .some(let string)): return string
}
}
public func set(string: String, service: KeychainStorage.ServiceKey, key: KeychainStorage.StringKey) throws {
SAMKeychain.setAccessibilityType(kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly)
var error: NSError?
let result: Bool = SAMKeychain.setPassword(string, forService: service.rawValue, account: key.rawValue, error: &error)
switch (error, result) {
case (.some(let error), _):
throw KeychainStorageError.failure(
code: Int32(error.code),
description: "[KeychainStorage] Error setting string: \(error)"
)
case (_, false):
throw KeychainStorageError.failure(code: nil, description: "[KeychainStorage] Could not set string")
case (_, true): break
}
}
public func remove(service: KeychainStorage.ServiceKey, key: KeychainStorage.StringKey) throws {
try remove(service: service.rawValue, key: key.rawValue)
}
public func data(forService service: KeychainStorage.ServiceKey, key: KeychainStorage.DataKey) throws -> Data {
var error: NSError?
let result: Data? = SAMKeychain.passwordData(forService: service.rawValue, account: key.rawValue, error: &error)
switch (error, result) {
case (.some(let error), _):
throw KeychainStorageError.failure(
code: Int32(error.code),
description: "[KeychainStorage] Error retrieving data: \(error)"
)
case (_, .none):
throw KeychainStorageError.failure(code: nil, description: "[KeychainStorage] Could not retrieve data")
case (_, .some(let data)): return data
}
}
public func set(data: Data, service: KeychainStorage.ServiceKey, key: KeychainStorage.DataKey) throws {
SAMKeychain.setAccessibilityType(kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly)
var error: NSError?
let result: Bool = SAMKeychain.setPasswordData(data, forService: service.rawValue, account: key.rawValue, error: &error)
switch (error, result) {
case (.some(let error), _):
throw KeychainStorageError.failure(
code: Int32(error.code),
description: "[KeychainStorage] Error setting data: \(error)"
)
case (_, false):
throw KeychainStorageError.failure(code: nil, description: "[KeychainStorage] Could not set data")
case (_, true): break
}
}
public func remove(service: KeychainStorage.ServiceKey, key: KeychainStorage.DataKey) throws {
try remove(service: service.rawValue, key: key.rawValue)
}
private func remove(service: String, key: String) throws {
var error: NSError?
let result: Bool = SAMKeychain.deletePassword(forService: service, account: key, error: &error)
switch (error, result) {
case (.some(let error), _):
/// If deletion failed because the specified item could not be found in the keychain, consider it success
guard error.code != errSecItemNotFound else { return }
throw KeychainStorageError.failure(
code: Int32(error.code),
description: "[KeychainStorage] Error removing data: \(error)"
)
case (_, false):
throw KeychainStorageError.failure(code: nil, description: "[KeychainStorage] Could not remove data")
case (_, true): break
}
}
public func removeAll() {
let allData: [[String: Any]] = SAMKeychain.allAccounts().defaulting(to: [])
allData.forEach { keychainEntry in
guard
let service: String = keychainEntry[kSAMKeychainWhereKey] as? String,
let key: String = keychainEntry[kSAMKeychainAccountKey] as? String
else { return }
try? remove(service: service, key: key)
}
}
}
// MARK: - Keys
public extension KeychainStorage {
struct ServiceKey: RawRepresentable, ExpressibleByStringLiteral, Hashable {
public let rawValue: String
public init(_ rawValue: String) { self.rawValue = rawValue }
public init?(rawValue: String) { self.rawValue = rawValue }
public init(stringLiteral value: String) { self.init(value) }
public init(unicodeScalarLiteral value: String) { self.init(value) }
public init(extendedGraphemeClusterLiteral value: String) { self.init(value) }
}
struct DataKey: RawRepresentable, ExpressibleByStringLiteral, Hashable {
public let rawValue: String
public init(_ rawValue: String) { self.rawValue = rawValue }
public init?(rawValue: String) { self.rawValue = rawValue }
public init(stringLiteral value: String) { self.init(value) }
public init(unicodeScalarLiteral value: String) { self.init(value) }
public init(extendedGraphemeClusterLiteral value: String) { self.init(value) }
}
struct StringKey: RawRepresentable, ExpressibleByStringLiteral, Hashable {
public let rawValue: String
public init(_ rawValue: String) { self.rawValue = rawValue }
public init?(rawValue: String) { self.rawValue = rawValue }
public init(stringLiteral value: String) { self.init(value) }
public init(unicodeScalarLiteral value: String) { self.init(value) }
public init(extendedGraphemeClusterLiteral value: String) { self.init(value) }
}
}

@ -14,7 +14,13 @@ public class TestDependencies: Dependencies {
// MARK: - Subscript Access // MARK: - Subscript Access
override public subscript<S>(singleton singleton: SingletonConfig<S>) -> S { override public subscript<S>(singleton singleton: SingletonConfig<S>) -> S {
return getValueSettingIfNull(singleton: singleton, &singletonInstances) guard let value: S = (singletonInstances[singleton.identifier] as? S) else {
let value: S = singleton.createInstance(self)
singletonInstances[singleton.identifier] = value
return value
}
return value
} }
public subscript<S>(singleton singleton: SingletonConfig<S>) -> S? { public subscript<S>(singleton singleton: SingletonConfig<S>) -> S? {
@ -23,7 +29,14 @@ public class TestDependencies: Dependencies {
} }
override public subscript<M, I>(cache cache: CacheConfig<M, I>) -> I { override public subscript<M, I>(cache cache: CacheConfig<M, I>) -> I {
return getValueSettingIfNull(cache: cache, &cacheInstances) guard let value: M = (cacheInstances[cache.identifier] as? M) else {
let value: M = cache.createInstance(self)
let mutableInstance: MutableCacheType = cache.mutableInstance(value)
cacheInstances[cache.identifier] = mutableInstance
return cache.immutableInstance(value)
}
return cache.immutableInstance(value)
} }
public subscript<M, I>(cache cache: CacheConfig<M, I>) -> M? { public subscript<M, I>(cache cache: CacheConfig<M, I>) -> M? {
@ -32,7 +45,13 @@ public class TestDependencies: Dependencies {
} }
override public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType { override public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType {
return getValueSettingIfNull(defaults: defaults, &defaultsInstances) guard let value: UserDefaultsType = defaultsInstances[defaults.identifier] else {
let value: UserDefaultsType = defaults.createInstance(self)
defaultsInstances[defaults.identifier] = value
return value
}
return value
} }
public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType? { public subscript(defaults defaults: UserDefaultsConfig) -> UserDefaultsType? {
@ -142,48 +161,6 @@ public class TestDependencies: Dependencies {
return result.map { elements.remove($0) } return result.map { elements.remove($0) }
} }
// MARK: - Instance upserting
@discardableResult private func getValueSettingIfNull<S>(
singleton: SingletonConfig<S>,
_ store: inout [String: Any]
) -> S {
guard let value: S = (store[singleton.identifier] as? S) else {
let value: S = singleton.createInstance(self)
store[singleton.identifier] = value
return value
}
return value
}
@discardableResult private func getValueSettingIfNull<M, I>(
cache: CacheConfig<M, I>,
_ store: inout [String: MutableCacheType]
) -> I {
guard let value: M = (store[cache.identifier] as? M) else {
let value: M = cache.createInstance(self)
let mutableInstance: MutableCacheType = cache.mutableInstance(value)
store[cache.identifier] = mutableInstance
return cache.immutableInstance(value)
}
return cache.immutableInstance(value)
}
@discardableResult private func getValueSettingIfNull(
defaults: UserDefaultsConfig,
_ store: inout [String: (any UserDefaultsType)]
) -> UserDefaultsType {
guard let value: UserDefaultsType = store[defaults.identifier] else {
let value: UserDefaultsType = defaults.createInstance(self)
store[defaults.identifier] = value
return value
}
return value
}
} }
// MARK: - TestState Convenience // MARK: - TestState Convenience

Loading…
Cancel
Save