From 11e17c8a868ac08c95bffaab3a80d9935af27645 Mon Sep 17 00:00:00 2001 From: Marvin Zhang Date: Wed, 13 Nov 2024 16:43:33 +0800 Subject: [PATCH] fix: user controller issues --- core/controllers/router.go | 7 +- core/controllers/user.go | 17 +- core/controllers/user_test.go | 306 ++++++++++++++++++++++++++++++++-- core/user/service.go | 29 ++-- 4 files changed, 329 insertions(+), 30 deletions(-) diff --git a/core/controllers/router.go b/core/controllers/router.go index 409a9bee..0836043b 100644 --- a/core/controllers/router.go +++ b/core/controllers/router.go @@ -231,7 +231,7 @@ func InitRoutes(app *gin.Engine) (err error) { }...)) RegisterController(groups.AuthGroup, "/users", NewController[models.User]([]Action{ { - Method: http.MethodPost, + Method: http.MethodGet, Path: "/:id", HandlerFunc: GetUserById, }, @@ -245,6 +245,11 @@ func InitRoutes(app *gin.Engine) (err error) { Path: "", HandlerFunc: PostUser, }, + { + Method: http.MethodPut, + Path: "/:id", + HandlerFunc: PutUserById, + }, { Method: http.MethodPost, Path: "/:id/change-password", diff --git a/core/controllers/user.go b/core/controllers/user.go index a67b06ba..37a41266 100644 --- a/core/controllers/user.go +++ b/core/controllers/user.go @@ -2,6 +2,9 @@ package controllers import ( "errors" + "fmt" + "regexp" + "github.com/crawlab-team/crawlab/core/models/models" "github.com/crawlab-team/crawlab/core/models/service" "github.com/crawlab-team/crawlab/core/utils" @@ -97,6 +100,16 @@ func PostUser(c *gin.Context) { HandleErrorBadRequest(c, err) return } + + // Validate email format + if payload.Email != "" { + emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`) + if !emailRegex.MatchString(payload.Email) { + HandleErrorBadRequest(c, fmt.Errorf("invalid email format")) + return + } + } + if !payload.RoleId.IsZero() { _, err := service.NewModelService[models.Role]().GetById(payload.RoleId) if err != nil { @@ -221,9 +234,9 @@ func putUser(userId primitive.ObjectID, c *gin.Context) { // update user user.SetUpdated(u.Id) if user.Id.IsZero() { - user.Id = u.Id + user.Id = userId } - if err := modelSvc.ReplaceById(u.Id, user); err != nil { + if err := modelSvc.ReplaceById(userId, user); err != nil { HandleErrorInternalServerError(c, err) return } diff --git a/core/controllers/user_test.go b/core/controllers/user_test.go index c6c8a1ff..ebe6fa1f 100644 --- a/core/controllers/user_test.go +++ b/core/controllers/user_test.go @@ -5,14 +5,202 @@ import ( "github.com/crawlab-team/crawlab/core/middlewares" "github.com/crawlab-team/crawlab/core/models/models" "github.com/crawlab-team/crawlab/core/models/service" + "github.com/crawlab-team/crawlab/core/user" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson" "net/http" "net/http/httptest" "strings" "testing" + + "github.com/crawlab-team/crawlab/core/utils" ) +func TestGetUserById_Success(t *testing.T) { + SetupTestDB() + defer CleanupTestDB() + + // Create test user with required fields + modelSvc := service.NewModelService[models.User]() + u := models.User{ + Username: "testuser", + Email: "test@example.com", + Password: utils.EncryptMd5("testpassword"), // Add password + } + id, err := modelSvc.InsertOne(u) + require.Nil(t, err) + u.SetId(id) + + router := gin.Default() + router.Use(middlewares.AuthorizationMiddleware()) + router.GET("/users/:id", controllers.GetUserById) + + // Test valid ID + req, err := http.NewRequest(http.MethodGet, "/users/"+id.Hex(), nil) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Test invalid ID format + req, err = http.NewRequest(http.MethodGet, "/users/invalid-id", nil) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestGetUserList_Success(t *testing.T) { + SetupTestDB() + defer CleanupTestDB() + + modelSvc := service.NewModelService[models.User]() + + // Create test users with required fields + users := []models.User{ + {Username: "user1", Email: "user1@example.com", Password: utils.EncryptMd5("password1")}, + {Username: "user2", Email: "user2@example.com", Password: utils.EncryptMd5("password2")}, + {Username: "user3", Email: "user3@example.com", Password: utils.EncryptMd5("password3")}, + } + + for _, u := range users { + _, err := modelSvc.InsertOne(u) + assert.Nil(t, err) + } + + router := gin.Default() + router.Use(middlewares.AuthorizationMiddleware()) + router.GET("/users", controllers.GetUserList) + + // Test default pagination + req, err := http.NewRequest(http.MethodGet, "/users", nil) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Test with pagination parameters + req, err = http.NewRequest(http.MethodGet, "/users?page=1&size=2", nil) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestPostUser_Success(t *testing.T) { + SetupTestDB() + defer CleanupTestDB() + + router := gin.Default() + router.Use(middlewares.AuthorizationMiddleware()) + router.POST("/users", controllers.PostUser) + + // Test creating a new user with valid data + reqBody := strings.NewReader(`{ + "username": "newuser", + "password": "password123", + "email": "newuser@example.com" + }`) + req, err := http.NewRequest(http.MethodPost, "/users", reqBody) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + // Verify user was created + modelSvc := service.NewModelService[models.User]() + u, err := modelSvc.GetOne(bson.M{"username": "newuser"}, nil) + assert.Nil(t, err) + assert.Equal(t, "newuser", u.Username) + assert.Equal(t, "newuser@example.com", u.Email) + + // Test creating a user with invalid data + reqBody = strings.NewReader(`{ + "username": "", + "password": "", + "email": "invalid-email" + }`) + req, err = http.NewRequest(http.MethodPost, "/users", reqBody) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestPutUserById_Success(t *testing.T) { + SetupTestDB() + defer CleanupTestDB() + + modelSvc := service.NewModelService[models.User]() + u := models.User{} + id, err := modelSvc.InsertOne(u) + require.Nil(t, err) + u.SetId(id) + + router := gin.Default() + router.Use(middlewares.AuthorizationMiddleware()) + router.PUT("/users/:id", controllers.PutUserById) + + // Test case 1: Regular user update + reqBody := strings.NewReader(`{ + "id":"` + id.Hex() + `", + "username":"newUsername", + "email":"newEmail@test.com" + }`) + req, _ := http.NewRequest(http.MethodPut, "/users/"+id.Hex(), reqBody) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + // Make request + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Test case 2: Root admin user update (should not change username) + u.RootAdmin = true + err = modelSvc.ReplaceById(id, u) + assert.Nil(t, err) + + reqBody = strings.NewReader(`{ + "id":"` + id.Hex() + `", + "username":"attemptedNewUsername", + "email":"newEmail@test.com" + }`) + req, _ = http.NewRequest(http.MethodPut, "/users/"+id.Hex(), reqBody) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Verify username wasn't changed for root admin + updatedUser, err := modelSvc.GetById(id) + assert.Nil(t, err) + assert.NotEqual(t, "attemptedNewUsername", updatedUser.Username) +} + func TestPostUserChangePassword_Success(t *testing.T) { SetupTestDB() defer CleanupTestDB() @@ -20,14 +208,16 @@ func TestPostUserChangePassword_Success(t *testing.T) { modelSvc := service.NewModelService[models.User]() u := models.User{} id, err := modelSvc.InsertOne(u) - assert.Nil(t, err) + require.Nil(t, err) u.SetId(id) router := gin.Default() router.Use(middlewares.AuthorizationMiddleware()) router.POST("/users/:id/change-password", controllers.PostUserChangePassword) - password := "newPassword" + // Add validation for minimum password length + // Test case 1: Valid password + password := "validPassword123" reqBody := strings.NewReader(`{"password":"` + password + `"}`) req, _ := http.NewRequest(http.MethodPost, "/users/"+id.Hex()+"/change-password", reqBody) req.Header.Set("Content-Type", "application/json") @@ -35,8 +225,18 @@ func TestPostUserChangePassword_Success(t *testing.T) { w := httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) + + // Test case 2: Password too short + shortPassword := "1234" + reqBody = strings.NewReader(`{"password":"` + shortPassword + `"}`) + req, _ = http.NewRequest(http.MethodPost, "/users/"+id.Hex()+"/change-password", reqBody) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) } func TestGetUserMe_Success(t *testing.T) { @@ -46,7 +246,7 @@ func TestGetUserMe_Success(t *testing.T) { modelSvc := service.NewModelService[models.User]() u := models.User{} id, err := modelSvc.InsertOne(u) - assert.Nil(t, err) + require.Nil(t, err) u.SetId(id) router := gin.Default() @@ -63,27 +263,109 @@ func TestGetUserMe_Success(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) } -func TestPutUserById_Success(t *testing.T) { +func TestPutUserMe_Success(t *testing.T) { SetupTestDB() defer CleanupTestDB() + // Create test user with required fields modelSvc := service.NewModelService[models.User]() - u := models.User{} + u := models.User{ + Username: "originaluser", + Email: "original@example.com", + Password: utils.EncryptMd5("testpassword"), + } id, err := modelSvc.InsertOne(u) - assert.Nil(t, err) + require.Nil(t, err) u.SetId(id) + // Create token for user + userSvc, err := user.GetUserService() + require.Nil(t, err) + token, err := userSvc.MakeToken(&u) + require.Nil(t, err) + + // Create router router := gin.Default() router.Use(middlewares.AuthorizationMiddleware()) - router.PUT("/users/me", controllers.PutUserById) + router.PUT("/users/me", controllers.PutUserMe) - reqBody := strings.NewReader(`{"id":"` + id.Hex() + `","username":"newUsername","email":"newEmail@test.com"}`) - req, _ := http.NewRequest(http.MethodPut, "/users/me", reqBody) + // Test valid update + reqBody := strings.NewReader(`{ + "username": "updateduser", + "email": "updated@example.com" + }`) + req, err := http.NewRequest(http.MethodPut, "/users/me", reqBody) + assert.Nil(t, err) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", TestToken) + req.Header.Set("Authorization", token) w := httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) + + // Verify the update + updatedUser, err := modelSvc.GetById(id) + assert.Nil(t, err) + assert.Equal(t, "updateduser", updatedUser.Username) + assert.Equal(t, "updated@example.com", updatedUser.Email) + + // Verify password wasn't changed + assert.Equal(t, utils.EncryptMd5("testpassword"), updatedUser.Password) +} + +func TestPostUserMeChangePassword_Success(t *testing.T) { + SetupTestDB() + defer CleanupTestDB() + + // Create test user with initial password + modelSvc := service.NewModelService[models.User]() + u := models.User{ + Username: "testuser", + Password: utils.EncryptMd5("initialpassword"), + Email: "test@example.com", + } + id, err := modelSvc.InsertOne(u) + require.Nil(t, err) + u.SetId(id) + + // Create token for user + userSvc, err := user.GetUserService() + require.Nil(t, err) + token, err := userSvc.MakeToken(&u) + require.Nil(t, err) + + // Create router + router := gin.Default() + router.Use(middlewares.AuthorizationMiddleware()) + router.POST("/users/me/change-password", controllers.PostUserMeChangePassword) + + // Test valid password change + password := "newValidPassword123" + reqBody := strings.NewReader(`{"password":"` + password + `"}`) + req, err := http.NewRequest(http.MethodPost, "/users/me/change-password", reqBody) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", token) + + // Make request + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Verify password was changed + updatedUser, err := modelSvc.GetById(id) + assert.Nil(t, err) + assert.Equal(t, utils.EncryptMd5(password), updatedUser.Password) + + // Test invalid password (too short) + shortPassword := "123" + reqBody = strings.NewReader(`{"password":"` + shortPassword + `"}`) + req, err = http.NewRequest(http.MethodPost, "/users/me/change-password", reqBody) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", TestToken) + + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) } diff --git a/core/user/service.go b/core/user/service.go index db32dad4..13a09d05 100644 --- a/core/user/service.go +++ b/core/user/service.go @@ -2,6 +2,7 @@ package user import ( errors2 "errors" + "fmt" "github.com/apex/log" "github.com/crawlab-team/crawlab/core/constants" "github.com/crawlab-team/crawlab/core/errors" @@ -181,37 +182,35 @@ func (svc *Service) makeToken(user *models.User) (tokenStr string, err error) { func (svc *Service) checkToken(tokenStr string) (user *models.User, err error) { token, err := jwt.Parse(tokenStr, svc.getSecretFunc()) if err != nil { - return + return nil, errors2.New("invalid token") } claim, ok := token.Claims.(jwt.MapClaims) if !ok { - err = errors.ErrorUserInvalidType - return + return nil, errors2.New("invalid type") } if !token.Valid { - err = errors.ErrorUserInvalidToken - return + return nil, errors2.New("invalid token") } id, err := primitive.ObjectIDFromHex(claim["id"].(string)) if err != nil { - return user, err + return nil, errors2.New("invalid token") } + fmt.Println(id) username := claim["username"].(string) - user, err = service.NewModelService[models.User]().GetById(id) + u, err := service.NewModelService[models.User]().GetById(id) if err != nil { - err = errors.ErrorUserNotExists - return + return nil, errors2.New("user not exists") + } + fmt.Println(fmt.Sprintf("%v", u)) + + if username != u.Username { + return nil, errors2.New("username mismatch") } - if username != user.Username { - err = errors.ErrorUserMismatch - return - } - - return + return u, nil } func (svc *Service) getSecretFunc() jwt.Keyfunc {