diff --git a/biz/handler.go b/biz/handler.go index 0f116c8..890372b 100644 --- a/biz/handler.go +++ b/biz/handler.go @@ -17,13 +17,13 @@ func Router() *gin.Engine { userRouter := v1.Group("/user") { - userRouter.POST("/create", middleware.ValidateAppleAppToken(), common.Wrapper(user.CreateUser)) - userRouter.GET("/info", middleware.ValidateAppleAppToken(), common.Wrapper(user.GetUserInfo)) + userRouter.POST("/login-with-apple", middleware.ValidateAppleAppLoginCode(), common.Wrapper(user.LoginWithApple)) + userRouter.GET("/info", middleware.ValidateToken(), common.Wrapper(user.GetUserInfo)) } if config.IsDebug() { // for debug - v1.GET("/ping", middleware.ValidateAppleAppToken(), func(ctx *gin.Context) { ctx.JSON(200, gin.H{"message": "pong"}) }) + v1.GET("/ping", middleware.ValidateToken(), func(ctx *gin.Context) { ctx.JSON(200, gin.H{"message": "pong"}) }) } return r } diff --git a/biz/user/create.go b/biz/user/create.go index b608644..278dcbf 100644 --- a/biz/user/create.go +++ b/biz/user/create.go @@ -2,23 +2,67 @@ package user import ( "context" + "time" "github.com/nose7en/ToyBoomServer/common" + "github.com/nose7en/ToyBoomServer/config" "github.com/nose7en/ToyBoomServer/dao" "github.com/nose7en/ToyBoomServer/defs" - "github.com/nose7en/ToyBoomServer/models" + "github.com/nose7en/ToyBoomServer/utils" ) -func CreateUser(c context.Context, req *defs.CommonRequest) (*defs.CommonResponse, error) { +func LoginWithApple(c context.Context, req *defs.CommonRequest) (*defs.GetUserAuthTokenResponse, error) { userInfo := common.GetUser(c) - newUser := &models.User{} - newUser.FillWithUserInfo(userInfo) - if err := dao.NewMutation().CreateUser(newUser); err != nil { + if userInfo.GetUserID() > 0 { + userInfo, err := dao.NewQuery().GetUserByID(userInfo.GetUserID()) + if err != nil { + common.Logger(c).WithError(err).Errorf("failed to get user info") + return nil, err + } + + newUserToken, err := newUserToken(userInfo) + if err != nil { + common.Logger(c).WithError(err).Errorf("failed to new user token") + return nil, err + } + + return &defs.GetUserAuthTokenResponse{ + Status: &defs.Status{Code: defs.RespCode_SUCCESS, Message: "user exists"}, + Token: newUserToken, + }, nil + } + + newUser, err := dao.NewMutation().FirstOrCreateUser(userInfo) + if err != nil { + common.Logger(c).WithError(err).Errorf("failed to create user") return nil, err } - return &defs.CommonResponse{ + newUserToken, err := newUserToken(newUser) + if err != nil { + common.Logger(c).WithError(err).Errorf("failed to new user token") + return nil, err + } + + common.Logger(c).Infof("create user success, user info record is %+v", newUser) + return &defs.GetUserAuthTokenResponse{ Status: &defs.Status{Code: defs.RespCode_SUCCESS, Message: defs.RespMessage_SUCCESS}, + Token: newUserToken, }, nil } + +func newUserToken(userInfo defs.UserGettable) (string, error) { + token, err := utils.GetJwtTokenFromMap( + config.GetSettings().JWTConfig.Secret, + time.Now().Unix(), + config.GetSettings().JWTConfig.ExpireSec, + userInfo.ToMap(), + ) + + if err != nil { + return "", err + } + + return token, nil +} diff --git a/common/context.go b/common/context.go index 6340ce7..b57298d 100644 --- a/common/context.go +++ b/common/context.go @@ -4,6 +4,7 @@ import ( "context" "github.com/nose7en/ToyBoomServer/defs" + "github.com/spf13/cast" ) func GetUser(c context.Context) defs.UserGettable { @@ -19,3 +20,17 @@ func GetUser(c context.Context) defs.UserGettable { return u } + +func GetToken(c context.Context) string { + return GetStrValue(c, TokenKey) +} + +func GetStrValue(c context.Context, key string) string { + val := c.Value(key) + return cast.ToString(val) +} + +func GetStrValueE(c context.Context, key string) (string, error) { + val := c.Value(key) + return cast.ToStringE(val) +} diff --git a/config/settings.go b/config/settings.go index 1581980..d52fea7 100644 --- a/config/settings.go +++ b/config/settings.go @@ -33,7 +33,8 @@ type RedisConf struct { } type JWTConfig struct { - Secret string `mapstructure:"secret"` + Secret string `mapstructure:"secret"` + ExpireSec int64 `mapstructure:"expire_sec"` } type AppleConf struct { @@ -62,6 +63,7 @@ func fillDefaultSettings() { viper.SetDefault("debug", false) viper.SetDefault("db.type", "sqlite") viper.SetDefault("db.dsn", "toyboom.db") + viper.SetDefault("jwt.expire_sec", 86400*30) // 30 days } func setConfigParams() { diff --git a/dao/interface.go b/dao/interface.go index 921699f..3bd6b02 100644 --- a/dao/interface.go +++ b/dao/interface.go @@ -2,26 +2,30 @@ package dao import ( "github.com/nose7en/ToyBoomServer/defs" + "github.com/nose7en/ToyBoomServer/storage" + "gorm.io/gorm" ) type Query interface { GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error) + GetUserByID(userID int64) (defs.UserGettable, error) } type Mutation interface { - CreateUser(user defs.UserGettable) error + FirstOrCreateUser(user defs.UserGettable) (defs.UserGettable, error) } var _ Query = (*queryImpl)(nil) var _ Mutation = (*mutationImpl)(nil) -type queryImpl struct{} -type mutationImpl struct{} +type queryImpl struct{ db *gorm.DB } + +type mutationImpl struct{ db *gorm.DB } func NewQuery() Query { - return &queryImpl{} + return &queryImpl{db: storage.GetDBManager().GetDefaultDB()} } func NewMutation() Mutation { - return &mutationImpl{} + return &mutationImpl{db: storage.GetDBManager().GetDefaultDB()} } diff --git a/dao/user.go b/dao/user.go index 3fcb5c8..772d196 100644 --- a/dao/user.go +++ b/dao/user.go @@ -1,13 +1,33 @@ package dao -import "github.com/nose7en/ToyBoomServer/defs" +import ( + "github.com/nose7en/ToyBoomServer/defs" + "github.com/nose7en/ToyBoomServer/models" + "gorm.io/gorm" +) func (q *queryImpl) GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error) { - - return nil, nil + user := &models.User{} + if result := q.db.Where(models.User{AppleUserID: appleUserID}).First(user); result.Error != nil { + return nil, result.Error + } + return user, nil } -func (m *mutationImpl) CreateUser(user defs.UserGettable) error { - - return nil +func (q *queryImpl) GetUserByID(userID int64) (defs.UserGettable, error) { + user := &models.User{} + if result := q.db.Where(models.User{Model: gorm.Model{ID: uint(userID)}}).First(user); result.Error != nil { + return nil, result.Error + } + return user, nil +} + +func (m *mutationImpl) FirstOrCreateUser(user defs.UserGettable) (defs.UserGettable, error) { + record := &models.User{} + newUserAttrs := &models.User{} + newUserAttrs.FromUserInfo(user) + result := m.db.Where(models.User{AppleUserID: user.GetAppleUserID()}). + Attrs(newUserAttrs). + FirstOrCreate(record) + return record, result.Error } diff --git a/defs/entity.go b/defs/entity.go index 235213a..0cba39b 100644 --- a/defs/entity.go +++ b/defs/entity.go @@ -1,7 +1,8 @@ package defs type User struct { - UserID string `json:"user_id"` + ID int64 `json:"user_id"` + AppleUserID string `json:"apple_user_id"` Name string `json:"name"` Username string `json:"username"` Email string `json:"email"` diff --git a/defs/user_info.go b/defs/user_info.go index 0700d66..fa78179 100644 --- a/defs/user_info.go +++ b/defs/user_info.go @@ -1,11 +1,16 @@ package defs +import "github.com/golang-jwt/jwt/v5" + type UserGettable interface { - GetUserID() string + GetUserID() int64 + GetAppleUserID() string GetName() string GetUsername() string GetEmail() string GetIsPrivateEmail() bool GetEmailVerified() bool + FromJWTClaims(jwt.MapClaims) ToUser() User + ToMap() map[string]string } diff --git a/go.mod b/go.mod index 3d7ad92..fc52f82 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/spf13/afero v1.11.0 // indirect - github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/cast v1.6.0 github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/middleware/auth.go b/middleware/auth.go index 79a613e..70f0186 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,19 +1,23 @@ package middleware import ( + "errors" "net/http" + "strings" "github.com/Timothylock/go-signin-with-apple/apple" + "github.com/golang-jwt/jwt/v5" "github.com/nose7en/ToyBoomServer/common" "github.com/nose7en/ToyBoomServer/config" - "github.com/nose7en/ToyBoomServer/defs" + "github.com/nose7en/ToyBoomServer/models" "github.com/nose7en/ToyBoomServer/rpc" + "github.com/nose7en/ToyBoomServer/utils" "github.com/spf13/cast" "github.com/gin-gonic/gin" ) -func ValidateAppleAppToken() func(c *gin.Context) { +func ValidateAppleAppLoginCode() func(c *gin.Context) { return func(c *gin.Context) { code := c.GetHeader(common.TokenKey) resp, err := rpc.GetManager().AppleCli().VerifyAppToken(c, code) @@ -47,8 +51,8 @@ func ValidateAppleAppToken() func(c *gin.Context) { emailVerified := cast.ToBool((*claim)["email_verified"]) isPrivateEmail := cast.ToBool((*claim)["is_private_email"]) - userInfo := &defs.User{ - UserID: unique, + userInfo := &models.User{ + AppleUserID: unique, Email: email, IsPrivateEmail: isPrivateEmail, EmailVerified: emailVerified, @@ -57,3 +61,74 @@ func ValidateAppleAppToken() func(c *gin.Context) { c.Set(common.UserInfoKey, userInfo) } } + +func ValidateToken() func(c *gin.Context) { + return func(c *gin.Context) { + tokens, err := GetAuthTokensFromAny(c) + if err != nil { + common.Logger(c).WithError(err).Error("failed to get auth token") + c.AbortWithStatusJSON(http.StatusOK, common.UnAuth("failed to get auth token")) + return + } + + var clms jwt.MapClaims + for tokenType, tokenStr := range tokens { + clms, err = validateToken(c, tokenStr) + if err == nil { + common.Logger(c).Infof("jwt middleware parse token success, token type: %s", tokenType) + break + } + } + + if err != nil { + common.Logger(c).WithError(err).Errorf("jwt middleware parse token error") + c.JSON(http.StatusOK, common.UnAuth("invalid authorization")) + c.Abort() + return + } + + userInfo := &models.User{} + userInfo.FromJWTClaims(clms) + if userInfo.GetUserID() <= 0 { + common.Logger(c).Errorf("failed to build user info from token") + c.AbortWithStatusJSON(http.StatusOK, common.UnAuth("invalid authorization")) + return + } + + common.Logger(c).Infof("token auth success, user info: %+v", userInfo) + c.Set(common.UserInfoKey, userInfo) + } +} + +func validateToken(ctx *gin.Context, tokenStr string) (u jwt.MapClaims, err error) { + if tokenStr == "" { + return nil, errors.New("token_str is empty") + } + + if t, err := utils.ParseToken(config.GetSettings().JWTConfig.Secret, tokenStr); err == nil { + for k, v := range t { + ctx.Set(k, v) + } + ctx.Set(common.TokenKey, tokenStr) + return t, nil + } + return nil, err +} + +func GetAuthTokensFromAny(c *gin.Context) (map[string]string, error) { + ans := map[string]string{} + if headerTokenStr := c.Request.Header.Get(common.AuthorizationKey); len(headerTokenStr) > 0 { + headerToken := strings.Split(headerTokenStr, " ") + if len(headerToken) == 2 { + ans[common.TokenKey] = headerToken[1] + } else { + ans[common.TokenKey] = headerTokenStr + } + } + + if len(ans) == 0 { + return nil, errors.New("auth token is empty") + } + + return ans, nil +} diff --git a/middleware/trace.go b/middleware/trace.go new file mode 100644 index 0000000..2d1089b --- /dev/null +++ b/middleware/trace.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "github.com/nose7en/ToyBoomServer/common" + "github.com/nose7en/ToyBoomServer/utils" +) + +func Trace() func(c *gin.Context) { + return func(c *gin.Context) { + UpStreamTraceID := c.GetHeader(common.TraceIDKey) + if len(UpStreamTraceID) > 0 { + c.Set(common.TraceIDKey, UpStreamTraceID) + common.Logger(c).Infof("get a common request, upstream traceid %v", + c.GetString(common.TraceIDKey)) + c.Next() + return + } + + c.Set(common.TraceIDKey, utils.GenerateUID()) + common.Logger(c).Infof("get a common request, gen traceid %v", + c.GetString(common.TraceIDKey)) + c.Next() + } +} diff --git a/models/user.go b/models/user.go index d5ae315..fea3c1c 100644 --- a/models/user.go +++ b/models/user.go @@ -1,7 +1,11 @@ package models import ( + "fmt" + + "github.com/golang-jwt/jwt/v5" "github.com/nose7en/ToyBoomServer/defs" + "github.com/spf13/cast" "gorm.io/gorm" ) @@ -17,16 +21,16 @@ func NewUserGettable(opt ...func(*User)) defs.UserGettable { type User struct { gorm.Model - UserID string `json:"user_id"` // user id from apple - Name string `json:"name"` // ToyBoom's user name - Username string `json:"username"` // user name from apple - Email string `json:"email"` // email from apple + AppleUserID string `json:"apple_user_id" gorm:"uniqueIndex"` // user id from apple + Name string `json:"name"` // ToyBoom's user name + Username string `json:"username"` // user name from apple + Email string `json:"email"` // email from apple IsPrivateEmail bool `json:"is_private_email"` EmailVerified bool `json:"email_verified"` } -func (u *User) FillWithUserInfo(userInfo defs.UserGettable) { - u.UserID = userInfo.GetUserID() +func (u *User) FromUserInfo(userInfo defs.UserGettable) { + u.AppleUserID = userInfo.GetAppleUserID() u.Name = userInfo.GetName() u.Username = userInfo.GetUsername() u.Email = userInfo.GetEmail() @@ -34,6 +38,28 @@ func (u *User) FillWithUserInfo(userInfo defs.UserGettable) { u.EmailVerified = userInfo.GetEmailVerified() } +func (u *User) FromJWTClaims(claims jwt.MapClaims) { + u.ID = cast.ToUint(claims["user_id"]) + u.AppleUserID = cast.ToString(claims["apple_user_id"]) + u.Name = cast.ToString(claims["name"]) + u.Username = cast.ToString(claims["username"]) + u.Email = cast.ToString(claims["email"]) + u.IsPrivateEmail = cast.ToBool(claims["is_private_email"]) + u.EmailVerified = cast.ToBool(claims["email_verified"]) +} + +func (u *User) ToMap() map[string]string { + return map[string]string{ + "user_id": fmt.Sprint(u.ID), + "apple_user_id": u.AppleUserID, + "name": u.Name, + "username": u.Username, + "email": u.Email, + "is_private_email": fmt.Sprint(u.IsPrivateEmail), + "email_verified": fmt.Sprint(u.EmailVerified), + } +} + func (u *User) GetName() string { return u.Name } @@ -42,8 +68,12 @@ func (u *User) GetUsername() string { return u.Username } -func (u *User) GetUserID() string { - return u.UserID +func (u *User) GetAppleUserID() string { + return u.AppleUserID +} + +func (u *User) GetUserID() int64 { + return int64(u.ID) } func (u *User) GetEmail() string { @@ -60,9 +90,10 @@ func (u *User) GetEmailVerified() bool { func (u *User) ToUser() defs.User { return defs.User{ - UserID: u.UserID, - Name: u.Name, - Username: u.Username, - Email: u.Email, + ID: int64(u.ID), + AppleUserID: u.AppleUserID, + Name: u.Name, + Username: u.Username, + Email: u.Email, } } diff --git a/storage/db.go b/storage/db.go index 7f641da..e8bd4bd 100644 --- a/storage/db.go +++ b/storage/db.go @@ -1,6 +1,9 @@ package storage import ( + "context" + + "github.com/nose7en/ToyBoomServer/common" "gorm.io/gorm" ) @@ -24,7 +27,9 @@ type dbManagerImpl struct { func (dbm *dbManagerImpl) Init(tables ...interface{}) { dbs := dbm.DBs[DefaultDBName] for _, db := range dbs { - db.AutoMigrate(tables...) + if err := db.AutoMigrate(tables...); err != nil { + common.Logger(context.Background()).WithError(err).Fatalf("auto migrate error!!!") + } } }