From 084a2b00268ed561f59ac19b1b6660a3c58573b3 Mon Sep 17 00:00:00 2001
From: 6543 <6543@obermui.de>
Date: Wed, 26 Feb 2020 07:32:22 +0100
Subject: [PATCH] Code Refactor of IssueWatch related things (#10401)

* refactor

* optimize

* remove Iretating function
LoadWatchUsers do not load Users into IW object and it is used only in api ... so move this logic

* remove unessesary

* Apply suggestions from code review

Thx

Co-Authored-By: guillep2k <18600385+guillep2k@users.noreply.github.com>

* make Tests more robust

* fix rebase

* restart CI

* CI no dont hit sqlites deadlock

Co-authored-by: guillep2k <18600385+guillep2k@users.noreply.github.com>
---
 integrations/repofiles_update_test.go     | 13 +++--
 models/issue_watch.go                     | 40 +++-----------
 models/notification.go                    | 66 +++++++++++------------
 models/repo_watch.go                      |  6 ++-
 models/user.go                            |  2 +-
 modules/git/repo_branch.go                |  3 ++
 modules/test/context_tests.go             |  7 ++-
 routers/api/v1/repo/issue_subscription.go |  9 +++-
 8 files changed, 66 insertions(+), 80 deletions(-)

diff --git a/integrations/repofiles_update_test.go b/integrations/repofiles_update_test.go
index a7beec4955..c422483bf8 100644
--- a/integrations/repofiles_update_test.go
+++ b/integrations/repofiles_update_test.go
@@ -207,11 +207,14 @@ func TestCreateOrUpdateRepoFileForCreate(t *testing.T) {
 
 		commitID, _ := gitRepo.GetBranchCommitID(opts.NewBranch)
 		expectedFileResponse := getExpectedFileResponseForRepofilesCreate(commitID)
-		assert.EqualValues(t, expectedFileResponse.Content, fileResponse.Content)
-		assert.EqualValues(t, expectedFileResponse.Commit.SHA, fileResponse.Commit.SHA)
-		assert.EqualValues(t, expectedFileResponse.Commit.HTMLURL, fileResponse.Commit.HTMLURL)
-		assert.EqualValues(t, expectedFileResponse.Commit.Author.Email, fileResponse.Commit.Author.Email)
-		assert.EqualValues(t, expectedFileResponse.Commit.Author.Name, fileResponse.Commit.Author.Name)
+		assert.NotNil(t, expectedFileResponse)
+		if expectedFileResponse != nil {
+			assert.EqualValues(t, expectedFileResponse.Content, fileResponse.Content)
+			assert.EqualValues(t, expectedFileResponse.Commit.SHA, fileResponse.Commit.SHA)
+			assert.EqualValues(t, expectedFileResponse.Commit.HTMLURL, fileResponse.Commit.HTMLURL)
+			assert.EqualValues(t, expectedFileResponse.Commit.Author.Email, fileResponse.Commit.Author.Email)
+			assert.EqualValues(t, expectedFileResponse.Commit.Author.Name, fileResponse.Commit.Author.Name)
+		}
 	})
 }
 
diff --git a/models/issue_watch.go b/models/issue_watch.go
index c4732d784e..9046e4d2f7 100644
--- a/models/issue_watch.go
+++ b/models/issue_watch.go
@@ -68,10 +68,14 @@ func getIssueWatch(e Engine, userID, issueID int64) (iw *IssueWatch, exists bool
 // but avoids joining with `user` for performance reasons
 // User permissions must be verified elsewhere if required
 func GetIssueWatchersIDs(issueID int64) ([]int64, error) {
+	return getIssueWatchersIDs(x, issueID, true)
+}
+
+func getIssueWatchersIDs(e Engine, issueID int64, watching bool) ([]int64, error) {
 	ids := make([]int64, 0, 64)
-	return ids, x.Table("issue_watch").
+	return ids, e.Table("issue_watch").
 		Where("issue_id=?", issueID).
-		And("is_watching = ?", true).
+		And("is_watching = ?", watching).
 		Select("user_id").
 		Find(&ids)
 }
@@ -99,39 +103,9 @@ func getIssueWatchers(e Engine, issueID int64, listOptions ListOptions) (IssueWa
 }
 
 func removeIssueWatchersByRepoID(e Engine, userID int64, repoID int64) error {
-	iw := &IssueWatch{
-		IsWatching: false,
-	}
 	_, err := e.
 		Join("INNER", "issue", "`issue`.id = `issue_watch`.issue_id AND `issue`.repo_id = ?", repoID).
-		Cols("is_watching", "updated_unix").
 		Where("`issue_watch`.user_id = ?", userID).
-		Update(iw)
+		Delete(new(IssueWatch))
 	return err
 }
-
-// LoadWatchUsers return watching users
-func (iwl IssueWatchList) LoadWatchUsers() (users UserList, err error) {
-	return iwl.loadWatchUsers(x)
-}
-
-func (iwl IssueWatchList) loadWatchUsers(e Engine) (users UserList, err error) {
-	if len(iwl) == 0 {
-		return []*User{}, nil
-	}
-
-	var userIDs = make([]int64, 0, len(iwl))
-	for _, iw := range iwl {
-		if iw.IsWatching {
-			userIDs = append(userIDs, iw.UserID)
-		}
-	}
-
-	if len(userIDs) == 0 {
-		return []*User{}, nil
-	}
-
-	err = e.In("id", userIDs).Find(&users)
-
-	return
-}
diff --git a/models/notification.go b/models/notification.go
index e7217a6e04..c52d6c557a 100644
--- a/models/notification.go
+++ b/models/notification.go
@@ -133,55 +133,42 @@ func CreateOrUpdateIssueNotifications(issueID, commentID int64, notificationAuth
 }
 
 func createOrUpdateIssueNotifications(e Engine, issueID, commentID int64, notificationAuthorID int64) error {
-	issueWatches, err := getIssueWatchers(e, issueID, ListOptions{})
+	// init
+	toNotify := make(map[int64]struct{}, 32)
+	notifications, err := getNotificationsByIssueID(e, issueID)
 	if err != nil {
 		return err
 	}
-
 	issue, err := getIssueByID(e, issueID)
 	if err != nil {
 		return err
 	}
 
-	watches, err := getWatchers(e, issue.RepoID)
+	issueWatches, err := getIssueWatchersIDs(e, issueID, true)
 	if err != nil {
 		return err
 	}
+	for _, id := range issueWatches {
+		toNotify[id] = struct{}{}
+	}
 
-	notifications, err := getNotificationsByIssueID(e, issueID)
+	repoWatches, err := getRepoWatchersIDs(e, issue.RepoID)
 	if err != nil {
 		return err
 	}
-
-	alreadyNotified := make(map[int64]struct{}, len(issueWatches)+len(watches))
-
-	notifyUser := func(userID int64) error {
-		// do not send notification for the own issuer/commenter
-		if userID == notificationAuthorID {
-			return nil
-		}
-
-		if _, ok := alreadyNotified[userID]; ok {
-			return nil
-		}
-		alreadyNotified[userID] = struct{}{}
-
-		if notificationExists(notifications, issue.ID, userID) {
-			return updateIssueNotification(e, userID, issue.ID, commentID, notificationAuthorID)
-		}
-		return createIssueNotification(e, userID, issue, commentID, notificationAuthorID)
+	for _, id := range repoWatches {
+		toNotify[id] = struct{}{}
 	}
 
-	for _, issueWatch := range issueWatches {
-		// ignore if user unwatched the issue
-		if !issueWatch.IsWatching {
-			alreadyNotified[issueWatch.UserID] = struct{}{}
-			continue
-		}
-
-		if err := notifyUser(issueWatch.UserID); err != nil {
-			return err
-		}
+	// dont notify user who cause notification
+	delete(toNotify, notificationAuthorID)
+	// explicit unwatch on issue
+	issueUnWatches, err := getIssueWatchersIDs(e, issueID, false)
+	if err != nil {
+		return err
+	}
+	for _, id := range issueUnWatches {
+		delete(toNotify, id)
 	}
 
 	err = issue.loadRepo(e)
@@ -189,16 +176,23 @@ func createOrUpdateIssueNotifications(e Engine, issueID, commentID int64, notifi
 		return err
 	}
 
-	for _, watch := range watches {
+	// notify
+	for userID := range toNotify {
 		issue.Repo.Units = nil
-		if issue.IsPull && !issue.Repo.checkUnitUser(e, watch.UserID, false, UnitTypePullRequests) {
+		if issue.IsPull && !issue.Repo.checkUnitUser(e, userID, false, UnitTypePullRequests) {
 			continue
 		}
-		if !issue.IsPull && !issue.Repo.checkUnitUser(e, watch.UserID, false, UnitTypeIssues) {
+		if !issue.IsPull && !issue.Repo.checkUnitUser(e, userID, false, UnitTypeIssues) {
 			continue
 		}
 
-		if err := notifyUser(watch.UserID); err != nil {
+		if notificationExists(notifications, issue.ID, userID) {
+			if err = updateIssueNotification(e, userID, issue.ID, commentID, notificationAuthorID); err != nil {
+				return err
+			}
+			continue
+		}
+		if err = createIssueNotification(e, userID, issue, commentID, notificationAuthorID); err != nil {
 			return err
 		}
 	}
diff --git a/models/repo_watch.go b/models/repo_watch.go
index a9d56eff03..11cfa88918 100644
--- a/models/repo_watch.go
+++ b/models/repo_watch.go
@@ -144,8 +144,12 @@ func GetWatchers(repoID int64) ([]*Watch, error) {
 // but avoids joining with `user` for performance reasons
 // User permissions must be verified elsewhere if required
 func GetRepoWatchersIDs(repoID int64) ([]int64, error) {
+	return getRepoWatchersIDs(x, repoID)
+}
+
+func getRepoWatchersIDs(e Engine, repoID int64) ([]int64, error) {
 	ids := make([]int64, 0, 64)
-	return ids, x.Table("watch").
+	return ids, e.Table("watch").
 		Where("watch.repo_id=?", repoID).
 		And("watch.mode<>?", RepoWatchModeDont).
 		Select("user_id").
diff --git a/models/user.go b/models/user.go
index bf59c1240b..8be15ba6df 100644
--- a/models/user.go
+++ b/models/user.go
@@ -1409,7 +1409,7 @@ func GetUserNamesByIDs(ids []int64) ([]string, error) {
 }
 
 // GetUsersByIDs returns all resolved users from a list of Ids.
-func GetUsersByIDs(ids []int64) ([]*User, error) {
+func GetUsersByIDs(ids []int64) (UserList, error) {
 	ous := make([]*User, 0, len(ids))
 	if len(ids) == 0 {
 		return ous, nil
diff --git a/modules/git/repo_branch.go b/modules/git/repo_branch.go
index e79bab76a6..3d0e6497ed 100644
--- a/modules/git/repo_branch.go
+++ b/modules/git/repo_branch.go
@@ -48,6 +48,9 @@ type Branch struct {
 
 // GetHEADBranch returns corresponding branch of HEAD.
 func (repo *Repository) GetHEADBranch() (*Branch, error) {
+	if repo == nil {
+		return nil, fmt.Errorf("nil repo")
+	}
 	stdout, err := NewCommand("symbolic-ref", "HEAD").RunInDir(repo.Path)
 	if err != nil {
 		return nil, err
diff --git a/modules/test/context_tests.go b/modules/test/context_tests.go
index cf9c5fbc54..f9f0ec5d42 100644
--- a/modules/test/context_tests.go
+++ b/modules/test/context_tests.go
@@ -58,8 +58,11 @@ func LoadRepoCommit(t *testing.T, ctx *context.Context) {
 	defer gitRepo.Close()
 	branch, err := gitRepo.GetHEADBranch()
 	assert.NoError(t, err)
-	ctx.Repo.Commit, err = gitRepo.GetBranchCommit(branch.Name)
-	assert.NoError(t, err)
+	assert.NotNil(t, branch)
+	if branch != nil {
+		ctx.Repo.Commit, err = gitRepo.GetBranchCommit(branch.Name)
+		assert.NoError(t, err)
+	}
 }
 
 // LoadUser load a user into a test context.
diff --git a/routers/api/v1/repo/issue_subscription.go b/routers/api/v1/repo/issue_subscription.go
index 274da966fd..0406edd207 100644
--- a/routers/api/v1/repo/issue_subscription.go
+++ b/routers/api/v1/repo/issue_subscription.go
@@ -190,9 +190,14 @@ func GetIssueSubscribers(ctx *context.APIContext) {
 		return
 	}
 
-	users, err := iwl.LoadWatchUsers()
+	var userIDs = make([]int64, 0, len(iwl))
+	for _, iw := range iwl {
+		userIDs = append(userIDs, iw.UserID)
+	}
+
+	users, err := models.GetUsersByIDs(userIDs)
 	if err != nil {
-		ctx.Error(http.StatusInternalServerError, "LoadWatchUsers", err)
+		ctx.Error(http.StatusInternalServerError, "GetUsersByIDs", err)
 		return
 	}