From 5749f50533225b5d38fed1ed86b1c893cc0466b5 Mon Sep 17 00:00:00 2001 From: Andreas Gerstmayr Date: Thu, 25 Nov 2021 18:49:52 +0100 Subject: [PATCH] notifications: use HMAC-SHA256 to generate password reset tokens * changes the time limit code generation function to use HMAC-SHA256 instead of SHA-1 * multiple new testcases diff --git a/pkg/services/notifications/codes.go b/pkg/services/notifications/codes.go index 32cd5dd7cd..72d33e3814 100644 --- a/pkg/services/notifications/codes.go +++ b/pkg/services/notifications/codes.go @@ -1,48 +1,53 @@ package notifications import ( - "crypto/sha1" // #nosec + "crypto/hmac" + "crypto/sha256" "encoding/hex" "fmt" + "strconv" "time" - "github.com/unknwon/com" - "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" ) -const timeLimitCodeLength = 12 + 6 + 40 +const timeLimitStartDateLength = 12 +const timeLimitMinutesLength = 6 +const timeLimitHmacLength = 64 +const timeLimitCodeLength = timeLimitStartDateLength + timeLimitMinutesLength + timeLimitHmacLength // create a time limit code -// code format: 12 length date time string + 6 minutes string + 40 sha1 encoded string -func createTimeLimitCode(data string, minutes int, startInf interface{}) (string, error) { +// code format: 12 length date time string + 6 minutes string + 64 HMAC-SHA256 encoded string +func createTimeLimitCode(payload string, minutes int, startStr string) (string, error) { format := "200601021504" var start, end time.Time - var startStr, endStr string + var endStr string - if startInf == nil { + if startStr == "" { // Use now time create code start = time.Now() startStr = start.Format(format) } else { // use start string create code - startStr = startInf.(string) - start, _ = time.ParseInLocation(format, startStr, time.Local) - startStr = start.Format(format) + var err error + start, err = time.ParseInLocation(format, startStr, time.Local) + if err != nil { + return "", err + } } end = start.Add(time.Minute * time.Duration(minutes)) endStr = end.Format(format) - // create sha1 encode string - sh := sha1.New() - if _, err := sh.Write([]byte(data + setting.SecretKey + startStr + endStr + - com.ToStr(minutes))); err != nil { - return "", err + // create HMAC-SHA256 encoded string + key := []byte(setting.SecretKey) + h := hmac.New(sha256.New, key) + if _, err := h.Write([]byte(payload + startStr + endStr)); err != nil { + return "", fmt.Errorf("cannot create hmac: %v", err) } - encoded := hex.EncodeToString(sh.Sum(nil)) + encoded := hex.EncodeToString(h.Sum(nil)) code := fmt.Sprintf("%s%06d%s", startStr, minutes, encoded) return code, nil @@ -50,29 +55,32 @@ func createTimeLimitCode(data string, minutes int, startInf interface{}) (string // verify time limit code func validateUserEmailCode(cfg *setting.Cfg, user *models.User, code string) (bool, error) { - if len(code) <= 18 { + if len(code) < timeLimitCodeLength { return false, nil } - minutes := cfg.EmailCodeValidMinutes code = code[:timeLimitCodeLength] // split code - start := code[:12] - lives := code[12:18] - if d, err := com.StrTo(lives).Int(); err == nil { - minutes = d + startStr := code[:timeLimitStartDateLength] + minutesStr := code[timeLimitStartDateLength : timeLimitStartDateLength+timeLimitMinutesLength] + minutes, err := strconv.Atoi(minutesStr) + if err != nil { + return false, fmt.Errorf("invalid time limit code: %v", err) } // right active code - data := com.ToStr(user.Id) + user.Email + user.Login + user.Password + user.Rands - retCode, err := createTimeLimitCode(data, minutes, start) + payload := strconv.FormatInt(user.Id, 10) + user.Email + user.Login + user.Password + user.Rands + expectedCode, err := createTimeLimitCode(payload, minutes, startStr) if err != nil { return false, err } - if retCode == code && minutes > 0 { + if hmac.Equal([]byte(code), []byte(expectedCode)) && minutes > 0 { // check time is expired or not - before, _ := time.ParseInLocation("200601021504", start, time.Local) + before, err := time.ParseInLocation("200601021504", startStr, time.Local) + if err != nil { + return false, err + } now := time.Now() if before.Add(time.Minute*time.Duration(minutes)).Unix() > now.Unix() { return true, nil @@ -93,15 +101,15 @@ func getLoginForEmailCode(code string) string { return string(b) } -func createUserEmailCode(cfg *setting.Cfg, u *models.User, startInf interface{}) (string, error) { +func createUserEmailCode(cfg *setting.Cfg, user *models.User, startStr string) (string, error) { minutes := cfg.EmailCodeValidMinutes - data := com.ToStr(u.Id) + u.Email + u.Login + u.Password + u.Rands - code, err := createTimeLimitCode(data, minutes, startInf) + payload := strconv.FormatInt(user.Id, 10) + user.Email + user.Login + user.Password + user.Rands + code, err := createTimeLimitCode(payload, minutes, startStr) if err != nil { return "", err } // add tail hex username - code += hex.EncodeToString([]byte(u.Login)) + code += hex.EncodeToString([]byte(user.Login)) return code, nil } diff --git a/pkg/services/notifications/codes_test.go b/pkg/services/notifications/codes_test.go index a314c8deca..be9b68ca69 100644 --- a/pkg/services/notifications/codes_test.go +++ b/pkg/services/notifications/codes_test.go @@ -1,7 +1,10 @@ package notifications import ( + "fmt" + "strconv" "testing" + "time" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" @@ -9,18 +12,126 @@ import ( "github.com/stretchr/testify/require" ) +func TestTimeLimitCodes(t *testing.T) { + cfg := setting.NewCfg() + cfg.EmailCodeValidMinutes = 120 + user := &models.User{Id: 10, Email: "t@a.com", Login: "asd", Password: "1", Rands: "2"} + + format := "200601021504" + mailPayload := strconv.FormatInt(user.Id, 10) + user.Email + user.Login + user.Password + user.Rands + tenMinutesAgo := time.Now().Add(-time.Minute * 10) + + tests := []struct { + desc string + payload string + start time.Time + minutes int + valid bool + }{ + { + desc: "code generated 10 minutes ago, 5 minutes valid", + payload: mailPayload, + start: tenMinutesAgo, + minutes: 5, + valid: false, + }, + { + desc: "code generated 10 minutes ago, 9 minutes valid", + payload: mailPayload, + start: tenMinutesAgo, + minutes: 9, + valid: false, + }, + { + desc: "code generated 10 minutes ago, 10 minutes valid", + payload: mailPayload, + start: tenMinutesAgo, + minutes: 10, + // code was valid exactly 10 minutes since evaluating the tenMinutesAgo assignment + // by the time this test is run the code is already expired + valid: false, + }, + { + desc: "code generated 10 minutes ago, 11 minutes valid", + payload: mailPayload, + start: tenMinutesAgo, + minutes: 11, + valid: true, + }, + { + desc: "code generated 10 minutes ago, 20 minutes valid", + payload: mailPayload, + start: tenMinutesAgo, + minutes: 20, + valid: true, + }, + { + desc: "code generated 10 minutes ago, 20 minutes valid, tampered payload", + payload: mailPayload[:len(mailPayload)-1] + "x", + start: tenMinutesAgo, + minutes: 20, + valid: false, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + code, err := createTimeLimitCode(test.payload, test.minutes, test.start.Format(format)) + require.NoError(t, err) + + isValid, err := validateUserEmailCode(cfg, user, code) + require.NoError(t, err) + require.Equal(t, test.valid, isValid) + }) + } + + t.Run("tampered minutes", func(t *testing.T) { + code, err := createTimeLimitCode(mailPayload, 5, tenMinutesAgo.Format(format)) + require.NoError(t, err) + + // code is expired + isValid, err := validateUserEmailCode(cfg, user, code) + require.NoError(t, err) + require.Equal(t, false, isValid) + + // let's try to extend the code by tampering the minutes + code = code[:12] + fmt.Sprintf("%06d", 20) + code[18:] + isValid, err = validateUserEmailCode(cfg, user, code) + require.NoError(t, err) + require.Equal(t, false, isValid) + }) + + t.Run("tampered start string", func(t *testing.T) { + code, err := createTimeLimitCode(mailPayload, 5, tenMinutesAgo.Format(format)) + require.NoError(t, err) + + // code is expired + isValid, err := validateUserEmailCode(cfg, user, code) + require.NoError(t, err) + require.Equal(t, false, isValid) + + // let's try to extend the code by tampering the start string + oneMinuteAgo := time.Now().Add(-time.Minute) + + code = oneMinuteAgo.Format(format) + code[12:] + isValid, err = validateUserEmailCode(cfg, user, code) + require.NoError(t, err) + require.Equal(t, false, isValid) + }) +} + func TestEmailCodes(t *testing.T) { t.Run("When generating code", func(t *testing.T) { cfg := setting.NewCfg() cfg.EmailCodeValidMinutes = 120 user := &models.User{Id: 10, Email: "t@a.com", Login: "asd", Password: "1", Rands: "2"} - code, err := createUserEmailCode(cfg, user, nil) + code, err := createUserEmailCode(cfg, user, "") require.NoError(t, err) t.Run("getLoginForCode should return login", func(t *testing.T) { login := getLoginForEmailCode(code) - require.Equal(t, login, "asd") + require.Equal(t, "asd", login) }) t.Run("Can verify valid code", func(t *testing.T) { @@ -29,7 +140,7 @@ func TestEmailCodes(t *testing.T) { require.True(t, isValid) }) - t.Run("Cannot verify in-valid code", func(t *testing.T) { + t.Run("Cannot verify invalid code", func(t *testing.T) { code = "ASD" isValid, err := validateUserEmailCode(cfg, user, code) require.NoError(t, err) diff --git a/pkg/services/notifications/notifications.go b/pkg/services/notifications/notifications.go index 84a0d42cb6..52facd0992 100644 --- a/pkg/services/notifications/notifications.go +++ b/pkg/services/notifications/notifications.go @@ -168,7 +168,7 @@ func (ns *NotificationService) SendEmailCommandHandler(ctx context.Context, cmd } func (ns *NotificationService) SendResetPasswordEmail(ctx context.Context, cmd *models.SendResetPasswordEmailCommand) error { - code, err := createUserEmailCode(ns.Cfg, cmd.User, nil) + code, err := createUserEmailCode(ns.Cfg, cmd.User, "") if err != nil { return err } diff --git a/pkg/services/notifications/notifications_test.go b/pkg/services/notifications/notifications_test.go index 71970e20a0..6f4b318fe0 100644 --- a/pkg/services/notifications/notifications_test.go +++ b/pkg/services/notifications/notifications_test.go @@ -2,6 +2,7 @@ package notifications import ( "context" + "regexp" "testing" "github.com/grafana/grafana/pkg/bus" @@ -185,7 +186,8 @@ func TestSendEmailAsync(t *testing.T) { t.Run("When sending reset email password", func(t *testing.T) { sut, _ := createSut(t, bus) - err := sut.SendResetPasswordEmail(context.Background(), &models.SendResetPasswordEmailCommand{User: &models.User{Email: "asd@asd.com"}}) + user := models.User{Email: "asd@asd.com", Login: "asd@asd.com"} + err := sut.SendResetPasswordEmail(context.Background(), &models.SendResetPasswordEmailCommand{User: &user}) require.NoError(t, err) sentMsg := <-sut.mailQueue @@ -194,6 +196,21 @@ func TestSendEmailAsync(t *testing.T) { assert.Equal(t, "Reset your Grafana password - asd@asd.com", sentMsg.Subject) assert.NotContains(t, sentMsg.Body["text/html"], "Subject") assert.NotContains(t, sentMsg.Body["text/plain"], "Subject") + + // find code in mail + r, _ := regexp.Compile(`code=(\w+)`) + match := r.FindString(sentMsg.Body["text/plain"]) + code := match[len("code="):] + + // verify code + query := models.ValidateResetPasswordCodeQuery{Code: code} + getUserByLogin := func(ctx context.Context, login string) (*models.User, error) { + query := models.GetUserByLoginQuery{LoginOrEmail: login} + query.Result = &user + return query.Result, nil + } + err = sut.ValidateResetPasswordCode(context.Background(), &query, getUserByLogin) + require.NoError(t, err) }) t.Run("When SMTP disabled in configuration", func(t *testing.T) {