feat: login with apple

This commit is contained in:
VaalaCat 2024-09-04 18:15:49 +00:00
parent 13148b95e3
commit 0948d23239
13 changed files with 268 additions and 41 deletions

View File

@ -17,13 +17,13 @@ func Router() *gin.Engine {
userRouter := v1.Group("/user") userRouter := v1.Group("/user")
{ {
userRouter.POST("/create", middleware.ValidateAppleAppToken(), common.Wrapper(user.CreateUser)) userRouter.POST("/login-with-apple", middleware.ValidateAppleAppLoginCode(), common.Wrapper(user.LoginWithApple))
userRouter.GET("/info", middleware.ValidateAppleAppToken(), common.Wrapper(user.GetUserInfo)) userRouter.GET("/info", middleware.ValidateToken(), common.Wrapper(user.GetUserInfo))
} }
if config.IsDebug() { if config.IsDebug() {
// for debug // for debug
v1.GET("/ping", middleware.ValidateAppleAppToken(), func(ctx *gin.Context) { ctx.JSON(200, gin.H{"message": "pong"}) }) v1.GET("/ping", middleware.ValidateToken(), func(ctx *gin.Context) { ctx.JSON(200, gin.H{"message": "pong"}) })
} }
return r return r
} }

View File

@ -2,23 +2,67 @@ package user
import ( import (
"context" "context"
"time"
"github.com/nose7en/ToyBoomServer/common" "github.com/nose7en/ToyBoomServer/common"
"github.com/nose7en/ToyBoomServer/config"
"github.com/nose7en/ToyBoomServer/dao" "github.com/nose7en/ToyBoomServer/dao"
"github.com/nose7en/ToyBoomServer/defs" "github.com/nose7en/ToyBoomServer/defs"
"github.com/nose7en/ToyBoomServer/models" "github.com/nose7en/ToyBoomServer/utils"
) )
func CreateUser(c context.Context, req *defs.CommonRequest) (*defs.CommonResponse, error) { func LoginWithApple(c context.Context, req *defs.CommonRequest) (*defs.GetUserAuthTokenResponse, error) {
userInfo := common.GetUser(c) userInfo := common.GetUser(c)
newUser := &models.User{}
newUser.FillWithUserInfo(userInfo)
if err := dao.NewMutation().CreateUser(newUser); err != nil { if userInfo.GetUserID() > 0 {
userInfo, err := dao.NewQuery().GetUserByID(userInfo.GetUserID())
if err != nil {
common.Logger(c).WithError(err).Errorf("failed to get user info")
return nil, err return nil, err
} }
return &defs.CommonResponse{ newUserToken, err := newUserToken(userInfo)
Status: &defs.Status{Code: defs.RespCode_SUCCESS, Message: defs.RespMessage_SUCCESS}, if err != nil {
common.Logger(c).WithError(err).Errorf("failed to new user token")
return nil, err
}
return &defs.GetUserAuthTokenResponse{
Status: &defs.Status{Code: defs.RespCode_SUCCESS, Message: "user exists"},
Token: newUserToken,
}, nil }, nil
} }
newUser, err := dao.NewMutation().FirstOrCreateUser(userInfo)
if err != nil {
common.Logger(c).WithError(err).Errorf("failed to create user")
return nil, err
}
newUserToken, err := newUserToken(newUser)
if err != nil {
common.Logger(c).WithError(err).Errorf("failed to new user token")
return nil, err
}
common.Logger(c).Infof("create user success, user info record is %+v", newUser)
return &defs.GetUserAuthTokenResponse{
Status: &defs.Status{Code: defs.RespCode_SUCCESS, Message: defs.RespMessage_SUCCESS},
Token: newUserToken,
}, nil
}
func newUserToken(userInfo defs.UserGettable) (string, error) {
token, err := utils.GetJwtTokenFromMap(
config.GetSettings().JWTConfig.Secret,
time.Now().Unix(),
config.GetSettings().JWTConfig.ExpireSec,
userInfo.ToMap(),
)
if err != nil {
return "", err
}
return token, nil
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/nose7en/ToyBoomServer/defs" "github.com/nose7en/ToyBoomServer/defs"
"github.com/spf13/cast"
) )
func GetUser(c context.Context) defs.UserGettable { func GetUser(c context.Context) defs.UserGettable {
@ -19,3 +20,17 @@ func GetUser(c context.Context) defs.UserGettable {
return u return u
} }
func GetToken(c context.Context) string {
return GetStrValue(c, TokenKey)
}
func GetStrValue(c context.Context, key string) string {
val := c.Value(key)
return cast.ToString(val)
}
func GetStrValueE(c context.Context, key string) (string, error) {
val := c.Value(key)
return cast.ToStringE(val)
}

View File

@ -34,6 +34,7 @@ type RedisConf struct {
type JWTConfig struct { type JWTConfig struct {
Secret string `mapstructure:"secret"` Secret string `mapstructure:"secret"`
ExpireSec int64 `mapstructure:"expire_sec"`
} }
type AppleConf struct { type AppleConf struct {
@ -62,6 +63,7 @@ func fillDefaultSettings() {
viper.SetDefault("debug", false) viper.SetDefault("debug", false)
viper.SetDefault("db.type", "sqlite") viper.SetDefault("db.type", "sqlite")
viper.SetDefault("db.dsn", "toyboom.db") viper.SetDefault("db.dsn", "toyboom.db")
viper.SetDefault("jwt.expire_sec", 86400*30) // 30 days
} }
func setConfigParams() { func setConfigParams() {

View File

@ -2,26 +2,30 @@ package dao
import ( import (
"github.com/nose7en/ToyBoomServer/defs" "github.com/nose7en/ToyBoomServer/defs"
"github.com/nose7en/ToyBoomServer/storage"
"gorm.io/gorm"
) )
type Query interface { type Query interface {
GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error) GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error)
GetUserByID(userID int64) (defs.UserGettable, error)
} }
type Mutation interface { type Mutation interface {
CreateUser(user defs.UserGettable) error FirstOrCreateUser(user defs.UserGettable) (defs.UserGettable, error)
} }
var _ Query = (*queryImpl)(nil) var _ Query = (*queryImpl)(nil)
var _ Mutation = (*mutationImpl)(nil) var _ Mutation = (*mutationImpl)(nil)
type queryImpl struct{} type queryImpl struct{ db *gorm.DB }
type mutationImpl struct{}
type mutationImpl struct{ db *gorm.DB }
func NewQuery() Query { func NewQuery() Query {
return &queryImpl{} return &queryImpl{db: storage.GetDBManager().GetDefaultDB()}
} }
func NewMutation() Mutation { func NewMutation() Mutation {
return &mutationImpl{} return &mutationImpl{db: storage.GetDBManager().GetDefaultDB()}
} }

View File

@ -1,13 +1,33 @@
package dao package dao
import "github.com/nose7en/ToyBoomServer/defs" import (
"github.com/nose7en/ToyBoomServer/defs"
"github.com/nose7en/ToyBoomServer/models"
"gorm.io/gorm"
)
func (q *queryImpl) GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error) { func (q *queryImpl) GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error) {
user := &models.User{}
return nil, nil if result := q.db.Where(models.User{AppleUserID: appleUserID}).First(user); result.Error != nil {
return nil, result.Error
}
return user, nil
} }
func (m *mutationImpl) CreateUser(user defs.UserGettable) error { func (q *queryImpl) GetUserByID(userID int64) (defs.UserGettable, error) {
user := &models.User{}
return nil if result := q.db.Where(models.User{Model: gorm.Model{ID: uint(userID)}}).First(user); result.Error != nil {
return nil, result.Error
}
return user, nil
}
func (m *mutationImpl) FirstOrCreateUser(user defs.UserGettable) (defs.UserGettable, error) {
record := &models.User{}
newUserAttrs := &models.User{}
newUserAttrs.FromUserInfo(user)
result := m.db.Where(models.User{AppleUserID: user.GetAppleUserID()}).
Attrs(newUserAttrs).
FirstOrCreate(record)
return record, result.Error
} }

View File

@ -1,7 +1,8 @@
package defs package defs
type User struct { type User struct {
UserID string `json:"user_id"` ID int64 `json:"user_id"`
AppleUserID string `json:"apple_user_id"`
Name string `json:"name"` Name string `json:"name"`
Username string `json:"username"` Username string `json:"username"`
Email string `json:"email"` Email string `json:"email"`

View File

@ -1,11 +1,16 @@
package defs package defs
import "github.com/golang-jwt/jwt/v5"
type UserGettable interface { type UserGettable interface {
GetUserID() string GetUserID() int64
GetAppleUserID() string
GetName() string GetName() string
GetUsername() string GetUsername() string
GetEmail() string GetEmail() string
GetIsPrivateEmail() bool GetIsPrivateEmail() bool
GetEmailVerified() bool GetEmailVerified() bool
FromJWTClaims(jwt.MapClaims)
ToUser() User ToUser() User
ToMap() map[string]string
} }

2
go.mod
View File

@ -63,7 +63,7 @@ require (
github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/spf13/afero v1.11.0 // indirect github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect github.com/spf13/cast v1.6.0
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect

View File

@ -1,19 +1,23 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"strings"
"github.com/Timothylock/go-signin-with-apple/apple" "github.com/Timothylock/go-signin-with-apple/apple"
"github.com/golang-jwt/jwt/v5"
"github.com/nose7en/ToyBoomServer/common" "github.com/nose7en/ToyBoomServer/common"
"github.com/nose7en/ToyBoomServer/config" "github.com/nose7en/ToyBoomServer/config"
"github.com/nose7en/ToyBoomServer/defs" "github.com/nose7en/ToyBoomServer/models"
"github.com/nose7en/ToyBoomServer/rpc" "github.com/nose7en/ToyBoomServer/rpc"
"github.com/nose7en/ToyBoomServer/utils"
"github.com/spf13/cast" "github.com/spf13/cast"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func ValidateAppleAppToken() func(c *gin.Context) { func ValidateAppleAppLoginCode() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
code := c.GetHeader(common.TokenKey) code := c.GetHeader(common.TokenKey)
resp, err := rpc.GetManager().AppleCli().VerifyAppToken(c, code) resp, err := rpc.GetManager().AppleCli().VerifyAppToken(c, code)
@ -47,8 +51,8 @@ func ValidateAppleAppToken() func(c *gin.Context) {
emailVerified := cast.ToBool((*claim)["email_verified"]) emailVerified := cast.ToBool((*claim)["email_verified"])
isPrivateEmail := cast.ToBool((*claim)["is_private_email"]) isPrivateEmail := cast.ToBool((*claim)["is_private_email"])
userInfo := &defs.User{ userInfo := &models.User{
UserID: unique, AppleUserID: unique,
Email: email, Email: email,
IsPrivateEmail: isPrivateEmail, IsPrivateEmail: isPrivateEmail,
EmailVerified: emailVerified, EmailVerified: emailVerified,
@ -57,3 +61,74 @@ func ValidateAppleAppToken() func(c *gin.Context) {
c.Set(common.UserInfoKey, userInfo) c.Set(common.UserInfoKey, userInfo)
} }
} }
func ValidateToken() func(c *gin.Context) {
return func(c *gin.Context) {
tokens, err := GetAuthTokensFromAny(c)
if err != nil {
common.Logger(c).WithError(err).Error("failed to get auth token")
c.AbortWithStatusJSON(http.StatusOK, common.UnAuth("failed to get auth token"))
return
}
var clms jwt.MapClaims
for tokenType, tokenStr := range tokens {
clms, err = validateToken(c, tokenStr)
if err == nil {
common.Logger(c).Infof("jwt middleware parse token success, token type: %s", tokenType)
break
}
}
if err != nil {
common.Logger(c).WithError(err).Errorf("jwt middleware parse token error")
c.JSON(http.StatusOK, common.UnAuth("invalid authorization"))
c.Abort()
return
}
userInfo := &models.User{}
userInfo.FromJWTClaims(clms)
if userInfo.GetUserID() <= 0 {
common.Logger(c).Errorf("failed to build user info from token")
c.AbortWithStatusJSON(http.StatusOK, common.UnAuth("invalid authorization"))
return
}
common.Logger(c).Infof("token auth success, user info: %+v", userInfo)
c.Set(common.UserInfoKey, userInfo)
}
}
func validateToken(ctx *gin.Context, tokenStr string) (u jwt.MapClaims, err error) {
if tokenStr == "" {
return nil, errors.New("token_str is empty")
}
if t, err := utils.ParseToken(config.GetSettings().JWTConfig.Secret, tokenStr); err == nil {
for k, v := range t {
ctx.Set(k, v)
}
ctx.Set(common.TokenKey, tokenStr)
return t, nil
}
return nil, err
}
func GetAuthTokensFromAny(c *gin.Context) (map[string]string, error) {
ans := map[string]string{}
if headerTokenStr := c.Request.Header.Get(common.AuthorizationKey); len(headerTokenStr) > 0 {
headerToken := strings.Split(headerTokenStr, " ")
if len(headerToken) == 2 {
ans[common.TokenKey] = headerToken[1]
} else {
ans[common.TokenKey] = headerTokenStr
}
}
if len(ans) == 0 {
return nil, errors.New("auth token is empty")
}
return ans, nil
}

25
middleware/trace.go Normal file
View File

@ -0,0 +1,25 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/nose7en/ToyBoomServer/common"
"github.com/nose7en/ToyBoomServer/utils"
)
func Trace() func(c *gin.Context) {
return func(c *gin.Context) {
UpStreamTraceID := c.GetHeader(common.TraceIDKey)
if len(UpStreamTraceID) > 0 {
c.Set(common.TraceIDKey, UpStreamTraceID)
common.Logger(c).Infof("get a common request, upstream traceid %v",
c.GetString(common.TraceIDKey))
c.Next()
return
}
c.Set(common.TraceIDKey, utils.GenerateUID())
common.Logger(c).Infof("get a common request, gen traceid %v",
c.GetString(common.TraceIDKey))
c.Next()
}
}

View File

@ -1,7 +1,11 @@
package models package models
import ( import (
"fmt"
"github.com/golang-jwt/jwt/v5"
"github.com/nose7en/ToyBoomServer/defs" "github.com/nose7en/ToyBoomServer/defs"
"github.com/spf13/cast"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -17,7 +21,7 @@ func NewUserGettable(opt ...func(*User)) defs.UserGettable {
type User struct { type User struct {
gorm.Model gorm.Model
UserID string `json:"user_id"` // user id from apple AppleUserID string `json:"apple_user_id" gorm:"uniqueIndex"` // user id from apple
Name string `json:"name"` // ToyBoom's user name Name string `json:"name"` // ToyBoom's user name
Username string `json:"username"` // user name from apple Username string `json:"username"` // user name from apple
Email string `json:"email"` // email from apple Email string `json:"email"` // email from apple
@ -25,8 +29,8 @@ type User struct {
EmailVerified bool `json:"email_verified"` EmailVerified bool `json:"email_verified"`
} }
func (u *User) FillWithUserInfo(userInfo defs.UserGettable) { func (u *User) FromUserInfo(userInfo defs.UserGettable) {
u.UserID = userInfo.GetUserID() u.AppleUserID = userInfo.GetAppleUserID()
u.Name = userInfo.GetName() u.Name = userInfo.GetName()
u.Username = userInfo.GetUsername() u.Username = userInfo.GetUsername()
u.Email = userInfo.GetEmail() u.Email = userInfo.GetEmail()
@ -34,6 +38,28 @@ func (u *User) FillWithUserInfo(userInfo defs.UserGettable) {
u.EmailVerified = userInfo.GetEmailVerified() u.EmailVerified = userInfo.GetEmailVerified()
} }
func (u *User) FromJWTClaims(claims jwt.MapClaims) {
u.ID = cast.ToUint(claims["user_id"])
u.AppleUserID = cast.ToString(claims["apple_user_id"])
u.Name = cast.ToString(claims["name"])
u.Username = cast.ToString(claims["username"])
u.Email = cast.ToString(claims["email"])
u.IsPrivateEmail = cast.ToBool(claims["is_private_email"])
u.EmailVerified = cast.ToBool(claims["email_verified"])
}
func (u *User) ToMap() map[string]string {
return map[string]string{
"user_id": fmt.Sprint(u.ID),
"apple_user_id": u.AppleUserID,
"name": u.Name,
"username": u.Username,
"email": u.Email,
"is_private_email": fmt.Sprint(u.IsPrivateEmail),
"email_verified": fmt.Sprint(u.EmailVerified),
}
}
func (u *User) GetName() string { func (u *User) GetName() string {
return u.Name return u.Name
} }
@ -42,8 +68,12 @@ func (u *User) GetUsername() string {
return u.Username return u.Username
} }
func (u *User) GetUserID() string { func (u *User) GetAppleUserID() string {
return u.UserID return u.AppleUserID
}
func (u *User) GetUserID() int64 {
return int64(u.ID)
} }
func (u *User) GetEmail() string { func (u *User) GetEmail() string {
@ -60,7 +90,8 @@ func (u *User) GetEmailVerified() bool {
func (u *User) ToUser() defs.User { func (u *User) ToUser() defs.User {
return defs.User{ return defs.User{
UserID: u.UserID, ID: int64(u.ID),
AppleUserID: u.AppleUserID,
Name: u.Name, Name: u.Name,
Username: u.Username, Username: u.Username,
Email: u.Email, Email: u.Email,

View File

@ -1,6 +1,9 @@
package storage package storage
import ( import (
"context"
"github.com/nose7en/ToyBoomServer/common"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -24,7 +27,9 @@ type dbManagerImpl struct {
func (dbm *dbManagerImpl) Init(tables ...interface{}) { func (dbm *dbManagerImpl) Init(tables ...interface{}) {
dbs := dbm.DBs[DefaultDBName] dbs := dbm.DBs[DefaultDBName]
for _, db := range dbs { for _, db := range dbs {
db.AutoMigrate(tables...) if err := db.AutoMigrate(tables...); err != nil {
common.Logger(context.Background()).WithError(err).Fatalf("auto migrate error!!!")
}
} }
} }