From 924b96856f7a99fdfc877ed41d7fa0068e4f3ded Mon Sep 17 00:00:00 2001
From: Simon Frei <freisim93@gmail.com>
Date: Mon, 22 Mar 2021 21:50:19 +0100
Subject: [PATCH] lib: Handle adding enc folders on an existing conn (fixes
 #7509) (#7510)

---
 lib/connections/service.go       |  8 +----
 lib/model/fakeconns_test.go      |  2 +-
 lib/model/mocks/model.go         | 12 +++----
 lib/model/model.go               | 41 +++++++++++++-----------
 lib/model/model_test.go          | 20 ++++++------
 lib/protocol/benchmark_test.go   |  6 ++--
 lib/protocol/common_test.go      |  2 +-
 lib/protocol/encryption.go       | 55 ++++++++++++++++++++++++--------
 lib/protocol/mocks/connection.go | 39 ++++++++++++++++++++++
 lib/protocol/protocol.go         | 23 ++++++-------
 lib/protocol/protocol_test.go    | 39 ++++++++++++++--------
 11 files changed, 160 insertions(+), 87 deletions(-)

diff --git a/lib/connections/service.go b/lib/connections/service.go
index e61a8c082..fa0ee7093 100644
--- a/lib/connections/service.go
+++ b/lib/connections/service.go
@@ -335,13 +335,7 @@ func (s *service) handle(ctx context.Context) error {
 		isLAN := s.isLAN(c.RemoteAddr())
 		rd, wr := s.limiter.getLimiters(remoteID, c, isLAN)
 
-		var protoConn protocol.Connection
-		passwords := s.cfg.FolderPasswords(remoteID)
-		if len(passwords) > 0 {
-			protoConn = protocol.NewEncryptedConnection(passwords, remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
-		} else {
-			protoConn = protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
-		}
+		protoConn := protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression, s.cfg.FolderPasswords(remoteID))
 
 		l.Infof("Established secure connection to %s at %s", remoteID, c)
 
diff --git a/lib/model/fakeconns_test.go b/lib/model/fakeconns_test.go
index c50137865..b1a8e91ef 100644
--- a/lib/model/fakeconns_test.go
+++ b/lib/model/fakeconns_test.go
@@ -33,7 +33,7 @@ func newFakeConnection(id protocol.DeviceID, model Model) *fakeConnection {
 	})
 	f.IDReturns(id)
 	f.CloseCalls(func(err error) {
-		model.Closed(f, err)
+		model.Closed(id, err)
 		f.ClosedReturns(true)
 	})
 	return f
diff --git a/lib/model/mocks/model.go b/lib/model/mocks/model.go
index 9e4048b19..8491d87d1 100644
--- a/lib/model/mocks/model.go
+++ b/lib/model/mocks/model.go
@@ -43,10 +43,10 @@ type Model struct {
 		arg1 string
 		arg2 string
 	}
-	ClosedStub        func(protocol.Connection, error)
+	ClosedStub        func(protocol.DeviceID, error)
 	closedMutex       sync.RWMutex
 	closedArgsForCall []struct {
-		arg1 protocol.Connection
+		arg1 protocol.DeviceID
 		arg2 error
 	}
 	ClusterConfigStub        func(protocol.DeviceID, protocol.ClusterConfig) error
@@ -684,10 +684,10 @@ func (fake *Model) BringToFrontArgsForCall(i int) (string, string) {
 	return argsForCall.arg1, argsForCall.arg2
 }
 
-func (fake *Model) Closed(arg1 protocol.Connection, arg2 error) {
+func (fake *Model) Closed(arg1 protocol.DeviceID, arg2 error) {
 	fake.closedMutex.Lock()
 	fake.closedArgsForCall = append(fake.closedArgsForCall, struct {
-		arg1 protocol.Connection
+		arg1 protocol.DeviceID
 		arg2 error
 	}{arg1, arg2})
 	stub := fake.ClosedStub
@@ -704,13 +704,13 @@ func (fake *Model) ClosedCallCount() int {
 	return len(fake.closedArgsForCall)
 }
 
-func (fake *Model) ClosedCalls(stub func(protocol.Connection, error)) {
+func (fake *Model) ClosedCalls(stub func(protocol.DeviceID, error)) {
 	fake.closedMutex.Lock()
 	defer fake.closedMutex.Unlock()
 	fake.ClosedStub = stub
 }
 
-func (fake *Model) ClosedArgsForCall(i int) (protocol.Connection, error) {
+func (fake *Model) ClosedArgsForCall(i int) (protocol.DeviceID, error) {
 	fake.closedMutex.RLock()
 	defer fake.closedMutex.RUnlock()
 	argsForCall := fake.closedArgsForCall[i]
diff --git a/lib/model/model.go b/lib/model/model.go
index 8f65ae373..c0024a02b 100644
--- a/lib/model/model.go
+++ b/lib/model/model.go
@@ -293,7 +293,7 @@ func (m *model) initFolders(cfg config.Configuration) error {
 	ignoredDevices := observedDeviceSet(m.cfg.IgnoredDevices())
 	m.cleanPending(cfg.DeviceMap(), cfg.FolderMap(), ignoredDevices, nil)
 
-	m.resendClusterConfig(clusterConfigDevices.AsSlice())
+	m.sendClusterConfig(clusterConfigDevices.AsSlice())
 	return nil
 }
 
@@ -1510,7 +1510,7 @@ func (m *model) ccCheckEncryption(fcfg config.FolderConfiguration, folderDevice
 			m.fmut.Unlock()
 			// We can only announce ourselfs once we have the token,
 			// thus we need to resend CCs now that we have it.
-			m.resendClusterConfig(fcfg.DeviceIDs())
+			m.sendClusterConfig(fcfg.DeviceIDs())
 			return nil
 		}
 	}
@@ -1520,7 +1520,7 @@ func (m *model) ccCheckEncryption(fcfg config.FolderConfiguration, folderDevice
 	return nil
 }
 
-func (m *model) resendClusterConfig(ids []protocol.DeviceID) {
+func (m *model) sendClusterConfig(ids []protocol.DeviceID) {
 	if len(ids) == 0 {
 		return
 	}
@@ -1534,7 +1534,8 @@ func (m *model) resendClusterConfig(ids []protocol.DeviceID) {
 	m.pmut.RUnlock()
 	// Generating cluster-configs acquires fmut -> must happen outside of pmut.
 	for _, conn := range ccConns {
-		cm := m.generateClusterConfig(conn.ID())
+		cm, passwords := m.generateClusterConfig(conn.ID())
+		conn.SetFolderPasswords(passwords)
 		go conn.ClusterConfig(cm)
 	}
 }
@@ -1728,9 +1729,7 @@ func (m *model) introduceDevice(device protocol.Device, introducerCfg config.Dev
 }
 
 // Closed is called when a connection has been closed
-func (m *model) Closed(conn protocol.Connection, err error) {
-	device := conn.ID()
-
+func (m *model) Closed(device protocol.DeviceID, err error) {
 	m.pmut.Lock()
 	conn, ok := m.conn[device]
 	if !ok {
@@ -2247,7 +2246,8 @@ func (m *model) AddConnection(conn protocol.Connection, hello protocol.Hello) {
 	m.pmut.Unlock()
 
 	// Acquires fmut, so has to be done outside of pmut.
-	cm := m.generateClusterConfig(deviceID)
+	cm, passwords := m.generateClusterConfig(deviceID)
+	conn.SetFolderPasswords(passwords)
 	conn.ClusterConfig(cm)
 
 	if (device.Name == "" || m.cfg.Options().OverwriteRemoteDevNames) && hello.DeviceName != "" {
@@ -2407,15 +2407,17 @@ func (m *model) numHashers(folder string) int {
 	return 1
 }
 
-// generateClusterConfig returns a ClusterConfigMessage that is correct for
-// the given peer device
-func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.ClusterConfig {
+// generateClusterConfig returns a ClusterConfigMessage that is correct and the
+// set of folder passwords for the given peer device
+func (m *model) generateClusterConfig(device protocol.DeviceID) (protocol.ClusterConfig, map[string]string) {
 	var message protocol.ClusterConfig
 
 	m.fmut.RLock()
 	defer m.fmut.RUnlock()
 
-	for _, folderCfg := range m.cfg.FolderList() {
+	folders := m.cfg.FolderList()
+	passwords := make(map[string]string, len(folders))
+	for _, folderCfg := range folders {
 		if !folderCfg.SharedWith(device) {
 			continue
 		}
@@ -2448,8 +2450,8 @@ func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.Cluster
 		// another cluster config once the folder is started.
 		protocolFolder.Paused = folderCfg.Paused || fs == nil
 
-		for _, device := range folderCfg.Devices {
-			deviceCfg, _ := m.cfg.Device(device.DeviceID)
+		for _, folderDevice := range folderCfg.Devices {
+			deviceCfg, _ := m.cfg.Device(folderDevice.DeviceID)
 
 			protocolDevice := protocol.Device{
 				ID:          deviceCfg.DeviceID,
@@ -2462,8 +2464,11 @@ func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.Cluster
 
 			if deviceCfg.DeviceID == m.id && hasEncryptionToken {
 				protocolDevice.EncryptionPasswordToken = encryptionToken
-			} else if device.EncryptionPassword != "" {
-				protocolDevice.EncryptionPasswordToken = protocol.PasswordToken(folderCfg.ID, device.EncryptionPassword)
+			} else if folderDevice.EncryptionPassword != "" {
+				protocolDevice.EncryptionPasswordToken = protocol.PasswordToken(folderCfg.ID, folderDevice.EncryptionPassword)
+				if folderDevice.DeviceID == device {
+					passwords[folderCfg.ID] = folderDevice.EncryptionPassword
+				}
 			}
 
 			if fs != nil {
@@ -2482,7 +2487,7 @@ func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.Cluster
 		message.Folders = append(message.Folders, protocolFolder)
 	}
 
-	return message
+	return message, passwords
 }
 
 func (m *model) State(folder string) (string, time.Time, error) {
@@ -2891,7 +2896,7 @@ func (m *model) CommitConfiguration(from, to config.Configuration) bool {
 	}
 	m.pmut.RUnlock()
 	// Generating cluster-configs acquires fmut -> must happen outside of pmut.
-	m.resendClusterConfig(clusterConfigDevices.AsSlice())
+	m.sendClusterConfig(clusterConfigDevices.AsSlice())
 
 	ignoredDevices := observedDeviceSet(to.IgnoredDevices)
 	m.cleanPending(toDevices, toFolders, ignoredDevices, removedFolders)
diff --git a/lib/model/model_test.go b/lib/model/model_test.go
index 736dfe628..096d2aa9b 100644
--- a/lib/model/model_test.go
+++ b/lib/model/model_test.go
@@ -341,7 +341,7 @@ func TestDeviceRename(t *testing.T) {
 		t.Errorf("Device already has a name")
 	}
 
-	m.Closed(conn, protocol.ErrTimeout)
+	m.Closed(conn.ID(), protocol.ErrTimeout)
 	hello.DeviceName = "tester"
 	m.AddConnection(conn, hello)
 
@@ -349,7 +349,7 @@ func TestDeviceRename(t *testing.T) {
 		t.Errorf("Device did not get a name")
 	}
 
-	m.Closed(conn, protocol.ErrTimeout)
+	m.Closed(conn.ID(), protocol.ErrTimeout)
 	hello.DeviceName = "tester2"
 	m.AddConnection(conn, hello)
 
@@ -367,7 +367,7 @@ func TestDeviceRename(t *testing.T) {
 		t.Errorf("Device name not saved in config")
 	}
 
-	m.Closed(conn, protocol.ErrTimeout)
+	m.Closed(conn.ID(), protocol.ErrTimeout)
 
 	waiter, err := cfg.Modify(func(cfg *config.Configuration) {
 		cfg.Options.OverwriteRemoteDevNames = true
@@ -428,7 +428,7 @@ func TestClusterConfig(t *testing.T) {
 	m.ServeBackground()
 	defer cleanupModel(m)
 
-	cm := m.generateClusterConfig(device2)
+	cm, _ := m.generateClusterConfig(device2)
 
 	if l := len(cm.Folders); l != 2 {
 		t.Fatalf("Incorrect number of folders %d != 2", l)
@@ -853,7 +853,7 @@ func TestIssue4897(t *testing.T) {
 	defer cleanupModel(m)
 	cancel()
 
-	cm := m.generateClusterConfig(device1)
+	cm, _ := m.generateClusterConfig(device1)
 	if l := len(cm.Folders); l != 1 {
 		t.Errorf("Cluster config contains %v folders, expected 1", l)
 	}
@@ -873,7 +873,7 @@ func TestIssue5063(t *testing.T) {
 	for _, c := range m.conn {
 		conn := c.(*fakeConnection)
 		conn.CloseCalls(func(_ error) {})
-		defer m.Closed(c, errStopped) // to unblock deferred m.Stop()
+		defer m.Closed(c.ID(), errStopped) // to unblock deferred m.Stop()
 	}
 	m.pmut.Unlock()
 
@@ -2428,8 +2428,8 @@ func TestNoRequestsFromPausedDevices(t *testing.T) {
 		t.Errorf("should have two available")
 	}
 
-	m.Closed(newFakeConnection(device1, m), errDeviceUnknown)
-	m.Closed(newFakeConnection(device2, m), errDeviceUnknown)
+	m.Closed(device1, errDeviceUnknown)
+	m.Closed(device2, errDeviceUnknown)
 
 	avail = m.testAvailability("default", file, file.Blocks[0])
 	if len(avail) != 0 {
@@ -3171,7 +3171,7 @@ func TestConnCloseOnRestart(t *testing.T) {
 
 	br := &testutils.BlockingRW{}
 	nw := &testutils.NoopRW{}
-	m.AddConnection(protocol.NewConnection(device1, br, nw, testutils.NoopCloser{}, m, new(protocolmocks.ConnectionInfo), protocol.CompressionNever), protocol.Hello{})
+	m.AddConnection(protocol.NewConnection(device1, br, nw, testutils.NoopCloser{}, m, new(protocolmocks.ConnectionInfo), protocol.CompressionNever, nil), protocol.Hello{})
 	m.pmut.RLock()
 	if len(m.closed) != 1 {
 		t.Fatalf("Expected just one conn (len(m.conn) == %v)", len(m.conn))
@@ -4142,7 +4142,7 @@ func TestCCFolderNotRunning(t *testing.T) {
 	defer cleanupModelAndRemoveDir(m, tfs.URI())
 
 	// A connection can happen before all the folders are started.
-	cc := m.generateClusterConfig(device1)
+	cc, _ := m.generateClusterConfig(device1)
 	if l := len(cc.Folders); l != 1 {
 		t.Fatalf("Expected 1 folder in CC, got %v", l)
 	}
diff --git a/lib/protocol/benchmark_test.go b/lib/protocol/benchmark_test.go
index 5b005f16b..1b3bd304e 100644
--- a/lib/protocol/benchmark_test.go
+++ b/lib/protocol/benchmark_test.go
@@ -60,9 +60,9 @@ func benchmarkRequestsTLS(b *testing.B, conn0, conn1 net.Conn) {
 
 func benchmarkRequestsConnPair(b *testing.B, conn0, conn1 net.Conn) {
 	// Start up Connections on them
-	c0 := NewConnection(LocalDeviceID, conn0, conn0, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata)
+	c0 := NewConnection(LocalDeviceID, conn0, conn0, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata, nil)
 	c0.Start()
-	c1 := NewConnection(LocalDeviceID, conn1, conn1, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata)
+	c1 := NewConnection(LocalDeviceID, conn1, conn1, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata, nil)
 	c1.Start()
 
 	// Satisfy the assertions in the protocol by sending an initial cluster config
@@ -188,7 +188,7 @@ func (m *fakeModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) error
 	return nil
 }
 
-func (m *fakeModel) Closed(conn Connection, err error) {
+func (m *fakeModel) Closed(DeviceID, error) {
 }
 
 func (m *fakeModel) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) error {
diff --git a/lib/protocol/common_test.go b/lib/protocol/common_test.go
index 29a5b1d29..6e5c7f02c 100644
--- a/lib/protocol/common_test.go
+++ b/lib/protocol/common_test.go
@@ -49,7 +49,7 @@ func (t *TestModel) Request(deviceID DeviceID, folder, name string, blockNo, siz
 	return &fakeRequestResponse{buf}, nil
 }
 
-func (t *TestModel) Closed(conn Connection, err error) {
+func (t *TestModel) Closed(_ DeviceID, err error) {
 	t.closedErr = err
 	close(t.closedCh)
 }
diff --git a/lib/protocol/encryption.go b/lib/protocol/encryption.go
index 0ccac6f42..f5cdebb1f 100644
--- a/lib/protocol/encryption.go
+++ b/lib/protocol/encryption.go
@@ -14,6 +14,7 @@ import (
 	"fmt"
 	"io"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/gogo/protobuf/proto"
@@ -41,11 +42,11 @@ const (
 // must decrypt those and answer requests by encrypting the data.
 type encryptedModel struct {
 	model      Model
-	folderKeys map[string]*[keySize]byte // folder ID -> key
+	folderKeys *folderKeyRegistry
 }
 
 func (e encryptedModel) Index(deviceID DeviceID, folder string, files []FileInfo) error {
-	if folderKey, ok := e.folderKeys[folder]; ok {
+	if folderKey, ok := e.folderKeys.get(folder); ok {
 		// incoming index data to be decrypted
 		if err := decryptFileInfos(files, folderKey); err != nil {
 			return err
@@ -55,7 +56,7 @@ func (e encryptedModel) Index(deviceID DeviceID, folder string, files []FileInfo
 }
 
 func (e encryptedModel) IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) error {
-	if folderKey, ok := e.folderKeys[folder]; ok {
+	if folderKey, ok := e.folderKeys.get(folder); ok {
 		// incoming index data to be decrypted
 		if err := decryptFileInfos(files, folderKey); err != nil {
 			return err
@@ -65,7 +66,7 @@ func (e encryptedModel) IndexUpdate(deviceID DeviceID, folder string, files []Fi
 }
 
 func (e encryptedModel) Request(deviceID DeviceID, folder, name string, blockNo, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) {
-	folderKey, ok := e.folderKeys[folder]
+	folderKey, ok := e.folderKeys.get(folder)
 	if !ok {
 		return e.model.Request(deviceID, folder, name, blockNo, size, offset, hash, weakHash, fromTemporary)
 	}
@@ -123,7 +124,7 @@ func (e encryptedModel) Request(deviceID DeviceID, folder, name string, blockNo,
 }
 
 func (e encryptedModel) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) error {
-	if _, ok := e.folderKeys[folder]; !ok {
+	if _, ok := e.folderKeys.get(folder); !ok {
 		return e.model.DownloadProgress(deviceID, folder, updates)
 	}
 
@@ -135,42 +136,46 @@ func (e encryptedModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) e
 	return e.model.ClusterConfig(deviceID, config)
 }
 
-func (e encryptedModel) Closed(conn Connection, err error) {
-	e.model.Closed(conn, err)
+func (e encryptedModel) Closed(device DeviceID, err error) {
+	e.model.Closed(device, err)
 }
 
 // The encryptedConnection sits between the model and the encrypted device. It
 // encrypts outgoing metadata and decrypts incoming responses.
 type encryptedConnection struct {
 	ConnectionInfo
-	conn       Connection
-	folderKeys map[string]*[keySize]byte // folder ID -> key
+	conn       *rawConnection
+	folderKeys *folderKeyRegistry
 }
 
 func (e encryptedConnection) Start() {
 	e.conn.Start()
 }
 
+func (e encryptedConnection) SetFolderPasswords(passwords map[string]string) {
+	e.folderKeys.setPasswords(passwords)
+}
+
 func (e encryptedConnection) ID() DeviceID {
 	return e.conn.ID()
 }
 
 func (e encryptedConnection) Index(ctx context.Context, folder string, files []FileInfo) error {
-	if folderKey, ok := e.folderKeys[folder]; ok {
+	if folderKey, ok := e.folderKeys.get(folder); ok {
 		encryptFileInfos(files, folderKey)
 	}
 	return e.conn.Index(ctx, folder, files)
 }
 
 func (e encryptedConnection) IndexUpdate(ctx context.Context, folder string, files []FileInfo) error {
-	if folderKey, ok := e.folderKeys[folder]; ok {
+	if folderKey, ok := e.folderKeys.get(folder); ok {
 		encryptFileInfos(files, folderKey)
 	}
 	return e.conn.IndexUpdate(ctx, folder, files)
 }
 
 func (e encryptedConnection) Request(ctx context.Context, folder string, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error) {
-	folderKey, ok := e.folderKeys[folder]
+	folderKey, ok := e.folderKeys.get(folder)
 	if !ok {
 		return e.conn.Request(ctx, folder, name, blockNo, offset, size, hash, weakHash, fromTemporary)
 	}
@@ -205,7 +210,7 @@ func (e encryptedConnection) Request(ctx context.Context, folder string, name st
 }
 
 func (e encryptedConnection) DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate) {
-	if _, ok := e.folderKeys[folder]; !ok {
+	if _, ok := e.folderKeys.get(folder); !ok {
 		e.conn.DownloadProgress(ctx, folder, updates)
 	}
 
@@ -590,3 +595,27 @@ func isEncryptedParentFromComponents(pathComponents []string) bool {
 	}
 	return true
 }
+
+type folderKeyRegistry struct {
+	keys map[string]*[keySize]byte // folder ID -> key
+	mut  sync.RWMutex
+}
+
+func newFolderKeyRegistry(passwords map[string]string) *folderKeyRegistry {
+	return &folderKeyRegistry{
+		keys: keysFromPasswords(passwords),
+	}
+}
+
+func (r *folderKeyRegistry) get(folder string) (*[keySize]byte, bool) {
+	r.mut.RLock()
+	key, ok := r.keys[folder]
+	r.mut.RUnlock()
+	return key, ok
+}
+
+func (r *folderKeyRegistry) setPasswords(passwords map[string]string) {
+	r.mut.Lock()
+	r.keys = keysFromPasswords(passwords)
+	r.mut.Unlock()
+}
diff --git a/lib/protocol/mocks/connection.go b/lib/protocol/mocks/connection.go
index c2c8cd381..a07ce4c2f 100644
--- a/lib/protocol/mocks/connection.go
+++ b/lib/protocol/mocks/connection.go
@@ -135,6 +135,11 @@ type Connection struct {
 		result1 []byte
 		result2 error
 	}
+	SetFolderPasswordsStub        func(map[string]string)
+	setFolderPasswordsMutex       sync.RWMutex
+	setFolderPasswordsArgsForCall []struct {
+		arg1 map[string]string
+	}
 	StartStub        func()
 	startMutex       sync.RWMutex
 	startArgsForCall []struct {
@@ -817,6 +822,38 @@ func (fake *Connection) RequestReturnsOnCall(i int, result1 []byte, result2 erro
 	}{result1, result2}
 }
 
+func (fake *Connection) SetFolderPasswords(arg1 map[string]string) {
+	fake.setFolderPasswordsMutex.Lock()
+	fake.setFolderPasswordsArgsForCall = append(fake.setFolderPasswordsArgsForCall, struct {
+		arg1 map[string]string
+	}{arg1})
+	stub := fake.SetFolderPasswordsStub
+	fake.recordInvocation("SetFolderPasswords", []interface{}{arg1})
+	fake.setFolderPasswordsMutex.Unlock()
+	if stub != nil {
+		fake.SetFolderPasswordsStub(arg1)
+	}
+}
+
+func (fake *Connection) SetFolderPasswordsCallCount() int {
+	fake.setFolderPasswordsMutex.RLock()
+	defer fake.setFolderPasswordsMutex.RUnlock()
+	return len(fake.setFolderPasswordsArgsForCall)
+}
+
+func (fake *Connection) SetFolderPasswordsCalls(stub func(map[string]string)) {
+	fake.setFolderPasswordsMutex.Lock()
+	defer fake.setFolderPasswordsMutex.Unlock()
+	fake.SetFolderPasswordsStub = stub
+}
+
+func (fake *Connection) SetFolderPasswordsArgsForCall(i int) map[string]string {
+	fake.setFolderPasswordsMutex.RLock()
+	defer fake.setFolderPasswordsMutex.RUnlock()
+	argsForCall := fake.setFolderPasswordsArgsForCall[i]
+	return argsForCall.arg1
+}
+
 func (fake *Connection) Start() {
 	fake.startMutex.Lock()
 	fake.startArgsForCall = append(fake.startArgsForCall, struct {
@@ -1080,6 +1117,8 @@ func (fake *Connection) Invocations() map[string][][]interface{} {
 	defer fake.remoteAddrMutex.RUnlock()
 	fake.requestMutex.RLock()
 	defer fake.requestMutex.RUnlock()
+	fake.setFolderPasswordsMutex.RLock()
+	defer fake.setFolderPasswordsMutex.RUnlock()
 	fake.startMutex.RLock()
 	defer fake.startMutex.RUnlock()
 	fake.statisticsMutex.RLock()
diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go
index 284df12d1..188822d37 100644
--- a/lib/protocol/protocol.go
+++ b/lib/protocol/protocol.go
@@ -126,8 +126,8 @@ type Model interface {
 	Request(deviceID DeviceID, folder, name string, blockNo, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error)
 	// A cluster configuration message was received
 	ClusterConfig(deviceID DeviceID, config ClusterConfig) error
-	// The peer device closed the connection
-	Closed(conn Connection, err error)
+	// The peer device closed the connection or an error occurred
+	Closed(device DeviceID, err error)
 	// The peer device sent progress updates for the files it is currently downloading
 	DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) error
 }
@@ -140,6 +140,7 @@ type RequestResponse interface {
 
 type Connection interface {
 	Start()
+	SetFolderPasswords(passwords map[string]string)
 	Close(err error)
 	ID() DeviceID
 	Index(ctx context.Context, folder string, files []FileInfo) error
@@ -225,24 +226,16 @@ const (
 // Should not be modified in production code, just for testing.
 var CloseTimeout = 10 * time.Second
 
-func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
-	receiver = nativeModel{receiver}
-	rc := newRawConnection(deviceID, reader, writer, closer, receiver, connInfo, compress)
-	return wireFormatConnection{rc}
-}
-
-func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
-	keys := keysFromPasswords(passwords)
-
+func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression, passwords map[string]string) Connection {
 	// Encryption / decryption is first (outermost) before conversion to
 	// native path formats.
 	nm := nativeModel{receiver}
-	em := encryptedModel{model: nm, folderKeys: keys}
+	em := &encryptedModel{model: nm, folderKeys: newFolderKeyRegistry(passwords)}
 
 	// We do the wire format conversion first (outermost) so that the
 	// metadata is in wire format when it reaches the encryption step.
 	rc := newRawConnection(deviceID, reader, writer, closer, em, connInfo, compress)
-	ec := encryptedConnection{ConnectionInfo: rc, conn: rc, folderKeys: keys}
+	ec := encryptedConnection{ConnectionInfo: rc, conn: rc, folderKeys: em.folderKeys}
 	wc := wireFormatConnection{ec}
 
 	return wc
@@ -748,6 +741,8 @@ func (c *rawConnection) writerLoop() {
 }
 
 func (c *rawConnection) writeMessage(msg message) error {
+	msgContext, _ := messageContext(msg)
+	l.Debugf("Writing %v", msgContext)
 	if c.shouldCompressMessage(msg) {
 		return c.writeCompressedMessage(msg)
 	}
@@ -955,7 +950,7 @@ func (c *rawConnection) internalClose(err error) {
 
 		<-c.dispatcherLoopStopped
 
-		c.receiver.Closed(c, err)
+		c.receiver.Closed(c.ID(), err)
 	})
 }
 
diff --git a/lib/protocol/protocol_test.go b/lib/protocol/protocol_test.go
index cfff22d77..a82cbb961 100644
--- a/lib/protocol/protocol_test.go
+++ b/lib/protocol/protocol_test.go
@@ -31,10 +31,10 @@ func TestPing(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c0 := getRawConnection(NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways, nil))
 	c0.Start()
 	defer closeAndWait(c0, ar, bw)
-	c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c1 := getRawConnection(NewConnection(c1ID, br, aw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways, nil))
 	c1.Start()
 	defer closeAndWait(c1, ar, bw)
 	c0.ClusterConfig(ClusterConfig{})
@@ -57,10 +57,10 @@ func TestClose(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c0 := getRawConnection(NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionAlways, nil))
 	c0.Start()
 	defer closeAndWait(c0, ar, bw)
-	c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionAlways)
+	c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionAlways, nil)
 	c1.Start()
 	defer closeAndWait(c1, ar, bw)
 	c0.ClusterConfig(ClusterConfig{})
@@ -102,7 +102,7 @@ func TestCloseOnBlockingSend(t *testing.T) {
 	m := newTestModel()
 
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := getRawConnection(NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
 	c.Start()
 	defer closeAndWait(c, rw)
 
@@ -153,10 +153,10 @@ func TestCloseRace(t *testing.T) {
 	ar, aw := io.Pipe()
 	br, bw := io.Pipe()
 
-	c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
+	c0 := getRawConnection(NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionNever, nil))
 	c0.Start()
 	defer closeAndWait(c0, ar, bw)
-	c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionNever)
+	c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionNever, nil)
 	c1.Start()
 	defer closeAndWait(c1, ar, bw)
 	c0.ClusterConfig(ClusterConfig{})
@@ -193,7 +193,7 @@ func TestClusterConfigFirst(t *testing.T) {
 	m := newTestModel()
 
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := getRawConnection(NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
 	c.Start()
 	defer closeAndWait(c, rw)
 
@@ -245,7 +245,7 @@ func TestCloseTimeout(t *testing.T) {
 	m := newTestModel()
 
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := getRawConnection(NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
 	c.Start()
 	defer closeAndWait(c, rw)
 
@@ -865,7 +865,7 @@ func TestClusterConfigAfterClose(t *testing.T) {
 	m := newTestModel()
 
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := getRawConnection(NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
 	c.Start()
 	defer closeAndWait(c, rw)
 
@@ -889,7 +889,7 @@ func TestDispatcherToCloseDeadlock(t *testing.T) {
 	// the model callbacks (ClusterConfig).
 	m := newTestModel()
 	rw := testutils.NewBlockingRW()
-	c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
+	c := getRawConnection(NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
 	m.ccFn = func(devID DeviceID, cc ClusterConfig) {
 		c.Close(errManual)
 	}
@@ -962,17 +962,28 @@ func TestIndexIDString(t *testing.T) {
 	}
 }
 
-func closeAndWait(c Connection, closers ...io.Closer) {
+func closeAndWait(c interface{}, closers ...io.Closer) {
 	for _, closer := range closers {
 		closer.Close()
 	}
 	var raw *rawConnection
 	switch i := c.(type) {
-	case wireFormatConnection:
-		raw = i.Connection.(*rawConnection)
 	case *rawConnection:
 		raw = i
+	default:
+		raw = getRawConnection(c.(Connection))
 	}
 	raw.internalClose(ErrClosed)
 	raw.loopWG.Wait()
 }
+
+func getRawConnection(c Connection) *rawConnection {
+	var raw *rawConnection
+	switch i := c.(type) {
+	case wireFormatConnection:
+		raw = i.Connection.(encryptedConnection).conn
+	case encryptedConnection:
+		raw = i.conn
+	}
+	return raw
+}
-- 
GitLab