feat: login with apple

This commit is contained in:
VaalaCat 2024-09-04 18:15:49 +00:00
parent 13148b95e3
commit 0948d23239
13 changed files with 268 additions and 41 deletions

View File

@ -17,13 +17,13 @@ func Router() *gin.Engine {
userRouter := v1.Group("/user")
{
userRouter.POST("/create", middleware.ValidateAppleAppToken(), common.Wrapper(user.CreateUser))
userRouter.GET("/info", middleware.ValidateAppleAppToken(), common.Wrapper(user.GetUserInfo))
userRouter.POST("/login-with-apple", middleware.ValidateAppleAppLoginCode(), common.Wrapper(user.LoginWithApple))
userRouter.GET("/info", middleware.ValidateToken(), common.Wrapper(user.GetUserInfo))
}
if config.IsDebug() {
// for debug
v1.GET("/ping", middleware.ValidateAppleAppToken(), func(ctx *gin.Context) { ctx.JSON(200, gin.H{"message": "pong"}) })
v1.GET("/ping", middleware.ValidateToken(), func(ctx *gin.Context) { ctx.JSON(200, gin.H{"message": "pong"}) })
}
return r
}

View File

@ -2,23 +2,67 @@ package user
import (
"context"
"time"
"github.com/nose7en/ToyBoomServer/common"
"github.com/nose7en/ToyBoomServer/config"
"github.com/nose7en/ToyBoomServer/dao"
"github.com/nose7en/ToyBoomServer/defs"
"github.com/nose7en/ToyBoomServer/models"
"github.com/nose7en/ToyBoomServer/utils"
)
func CreateUser(c context.Context, req *defs.CommonRequest) (*defs.CommonResponse, error) {
func LoginWithApple(c context.Context, req *defs.CommonRequest) (*defs.GetUserAuthTokenResponse, error) {
userInfo := common.GetUser(c)
newUser := &models.User{}
newUser.FillWithUserInfo(userInfo)
if err := dao.NewMutation().CreateUser(newUser); err != nil {
if userInfo.GetUserID() > 0 {
userInfo, err := dao.NewQuery().GetUserByID(userInfo.GetUserID())
if err != nil {
common.Logger(c).WithError(err).Errorf("failed to get user info")
return nil, err
}
newUserToken, err := newUserToken(userInfo)
if err != nil {
common.Logger(c).WithError(err).Errorf("failed to new user token")
return nil, err
}
return &defs.GetUserAuthTokenResponse{
Status: &defs.Status{Code: defs.RespCode_SUCCESS, Message: "user exists"},
Token: newUserToken,
}, nil
}
newUser, err := dao.NewMutation().FirstOrCreateUser(userInfo)
if err != nil {
common.Logger(c).WithError(err).Errorf("failed to create user")
return nil, err
}
return &defs.CommonResponse{
newUserToken, err := newUserToken(newUser)
if err != nil {
common.Logger(c).WithError(err).Errorf("failed to new user token")
return nil, err
}
common.Logger(c).Infof("create user success, user info record is %+v", newUser)
return &defs.GetUserAuthTokenResponse{
Status: &defs.Status{Code: defs.RespCode_SUCCESS, Message: defs.RespMessage_SUCCESS},
Token: newUserToken,
}, nil
}
func newUserToken(userInfo defs.UserGettable) (string, error) {
token, err := utils.GetJwtTokenFromMap(
config.GetSettings().JWTConfig.Secret,
time.Now().Unix(),
config.GetSettings().JWTConfig.ExpireSec,
userInfo.ToMap(),
)
if err != nil {
return "", err
}
return token, nil
}

View File

@ -4,6 +4,7 @@ import (
"context"
"github.com/nose7en/ToyBoomServer/defs"
"github.com/spf13/cast"
)
func GetUser(c context.Context) defs.UserGettable {
@ -19,3 +20,17 @@ func GetUser(c context.Context) defs.UserGettable {
return u
}
func GetToken(c context.Context) string {
return GetStrValue(c, TokenKey)
}
func GetStrValue(c context.Context, key string) string {
val := c.Value(key)
return cast.ToString(val)
}
func GetStrValueE(c context.Context, key string) (string, error) {
val := c.Value(key)
return cast.ToStringE(val)
}

View File

@ -33,7 +33,8 @@ type RedisConf struct {
}
type JWTConfig struct {
Secret string `mapstructure:"secret"`
Secret string `mapstructure:"secret"`
ExpireSec int64 `mapstructure:"expire_sec"`
}
type AppleConf struct {
@ -62,6 +63,7 @@ func fillDefaultSettings() {
viper.SetDefault("debug", false)
viper.SetDefault("db.type", "sqlite")
viper.SetDefault("db.dsn", "toyboom.db")
viper.SetDefault("jwt.expire_sec", 86400*30) // 30 days
}
func setConfigParams() {

View File

@ -2,26 +2,30 @@ package dao
import (
"github.com/nose7en/ToyBoomServer/defs"
"github.com/nose7en/ToyBoomServer/storage"
"gorm.io/gorm"
)
type Query interface {
GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error)
GetUserByID(userID int64) (defs.UserGettable, error)
}
type Mutation interface {
CreateUser(user defs.UserGettable) error
FirstOrCreateUser(user defs.UserGettable) (defs.UserGettable, error)
}
var _ Query = (*queryImpl)(nil)
var _ Mutation = (*mutationImpl)(nil)
type queryImpl struct{}
type mutationImpl struct{}
type queryImpl struct{ db *gorm.DB }
type mutationImpl struct{ db *gorm.DB }
func NewQuery() Query {
return &queryImpl{}
return &queryImpl{db: storage.GetDBManager().GetDefaultDB()}
}
func NewMutation() Mutation {
return &mutationImpl{}
return &mutationImpl{db: storage.GetDBManager().GetDefaultDB()}
}

View File

@ -1,13 +1,33 @@
package dao
import "github.com/nose7en/ToyBoomServer/defs"
import (
"github.com/nose7en/ToyBoomServer/defs"
"github.com/nose7en/ToyBoomServer/models"
"gorm.io/gorm"
)
func (q *queryImpl) GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error) {
return nil, nil
user := &models.User{}
if result := q.db.Where(models.User{AppleUserID: appleUserID}).First(user); result.Error != nil {
return nil, result.Error
}
return user, nil
}
func (m *mutationImpl) CreateUser(user defs.UserGettable) error {
return nil
func (q *queryImpl) GetUserByID(userID int64) (defs.UserGettable, error) {
user := &models.User{}
if result := q.db.Where(models.User{Model: gorm.Model{ID: uint(userID)}}).First(user); result.Error != nil {
return nil, result.Error
}
return user, nil
}
func (m *mutationImpl) FirstOrCreateUser(user defs.UserGettable) (defs.UserGettable, error) {
record := &models.User{}
newUserAttrs := &models.User{}
newUserAttrs.FromUserInfo(user)
result := m.db.Where(models.User{AppleUserID: user.GetAppleUserID()}).
Attrs(newUserAttrs).
FirstOrCreate(record)
return record, result.Error
}

View File

@ -1,7 +1,8 @@
package defs
type User struct {
UserID string `json:"user_id"`
ID int64 `json:"user_id"`
AppleUserID string `json:"apple_user_id"`
Name string `json:"name"`
Username string `json:"username"`
Email string `json:"email"`

View File

@ -1,11 +1,16 @@
package defs
import "github.com/golang-jwt/jwt/v5"
type UserGettable interface {
GetUserID() string
GetUserID() int64
GetAppleUserID() string
GetName() string
GetUsername() string
GetEmail() string
GetIsPrivateEmail() bool
GetEmailVerified() bool
FromJWTClaims(jwt.MapClaims)
ToUser() User
ToMap() map[string]string
}

2
go.mod
View File

@ -63,7 +63,7 @@ require (
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/cast v1.6.0
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect

View File

@ -1,19 +1,23 @@
package middleware
import (
"errors"
"net/http"
"strings"
"github.com/Timothylock/go-signin-with-apple/apple"
"github.com/golang-jwt/jwt/v5"
"github.com/nose7en/ToyBoomServer/common"
"github.com/nose7en/ToyBoomServer/config"
"github.com/nose7en/ToyBoomServer/defs"
"github.com/nose7en/ToyBoomServer/models"
"github.com/nose7en/ToyBoomServer/rpc"
"github.com/nose7en/ToyBoomServer/utils"
"github.com/spf13/cast"
"github.com/gin-gonic/gin"
)
func ValidateAppleAppToken() func(c *gin.Context) {
func ValidateAppleAppLoginCode() func(c *gin.Context) {
return func(c *gin.Context) {
code := c.GetHeader(common.TokenKey)
resp, err := rpc.GetManager().AppleCli().VerifyAppToken(c, code)
@ -47,8 +51,8 @@ func ValidateAppleAppToken() func(c *gin.Context) {
emailVerified := cast.ToBool((*claim)["email_verified"])
isPrivateEmail := cast.ToBool((*claim)["is_private_email"])
userInfo := &defs.User{
UserID: unique,
userInfo := &models.User{
AppleUserID: unique,
Email: email,
IsPrivateEmail: isPrivateEmail,
EmailVerified: emailVerified,
@ -57,3 +61,74 @@ func ValidateAppleAppToken() func(c *gin.Context) {
c.Set(common.UserInfoKey, userInfo)
}
}
func ValidateToken() func(c *gin.Context) {
return func(c *gin.Context) {
tokens, err := GetAuthTokensFromAny(c)
if err != nil {
common.Logger(c).WithError(err).Error("failed to get auth token")
c.AbortWithStatusJSON(http.StatusOK, common.UnAuth("failed to get auth token"))
return
}
var clms jwt.MapClaims
for tokenType, tokenStr := range tokens {
clms, err = validateToken(c, tokenStr)
if err == nil {
common.Logger(c).Infof("jwt middleware parse token success, token type: %s", tokenType)
break
}
}
if err != nil {
common.Logger(c).WithError(err).Errorf("jwt middleware parse token error")
c.JSON(http.StatusOK, common.UnAuth("invalid authorization"))
c.Abort()
return
}
userInfo := &models.User{}
userInfo.FromJWTClaims(clms)
if userInfo.GetUserID() <= 0 {
common.Logger(c).Errorf("failed to build user info from token")
c.AbortWithStatusJSON(http.StatusOK, common.UnAuth("invalid authorization"))
return
}
common.Logger(c).Infof("token auth success, user info: %+v", userInfo)
c.Set(common.UserInfoKey, userInfo)
}
}
func validateToken(ctx *gin.Context, tokenStr string) (u jwt.MapClaims, err error) {
if tokenStr == "" {
return nil, errors.New("token_str is empty")
}
if t, err := utils.ParseToken(config.GetSettings().JWTConfig.Secret, tokenStr); err == nil {
for k, v := range t {
ctx.Set(k, v)
}
ctx.Set(common.TokenKey, tokenStr)
return t, nil
}
return nil, err
}
func GetAuthTokensFromAny(c *gin.Context) (map[string]string, error) {
ans := map[string]string{}
if headerTokenStr := c.Request.Header.Get(common.AuthorizationKey); len(headerTokenStr) > 0 {
headerToken := strings.Split(headerTokenStr, " ")
if len(headerToken) == 2 {
ans[common.TokenKey] = headerToken[1]
} else {
ans[common.TokenKey] = headerTokenStr
}
}
if len(ans) == 0 {
return nil, errors.New("auth token is empty")
}
return ans, nil
}

25
middleware/trace.go Normal file
View File

@ -0,0 +1,25 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/nose7en/ToyBoomServer/common"
"github.com/nose7en/ToyBoomServer/utils"
)
func Trace() func(c *gin.Context) {
return func(c *gin.Context) {
UpStreamTraceID := c.GetHeader(common.TraceIDKey)
if len(UpStreamTraceID) > 0 {
c.Set(common.TraceIDKey, UpStreamTraceID)
common.Logger(c).Infof("get a common request, upstream traceid %v",
c.GetString(common.TraceIDKey))
c.Next()
return
}
c.Set(common.TraceIDKey, utils.GenerateUID())
common.Logger(c).Infof("get a common request, gen traceid %v",
c.GetString(common.TraceIDKey))
c.Next()
}
}

View File

@ -1,7 +1,11 @@
package models
import (
"fmt"
"github.com/golang-jwt/jwt/v5"
"github.com/nose7en/ToyBoomServer/defs"
"github.com/spf13/cast"
"gorm.io/gorm"
)
@ -17,16 +21,16 @@ func NewUserGettable(opt ...func(*User)) defs.UserGettable {
type User struct {
gorm.Model
UserID string `json:"user_id"` // user id from apple
Name string `json:"name"` // ToyBoom's user name
Username string `json:"username"` // user name from apple
Email string `json:"email"` // email from apple
AppleUserID string `json:"apple_user_id" gorm:"uniqueIndex"` // user id from apple
Name string `json:"name"` // ToyBoom's user name
Username string `json:"username"` // user name from apple
Email string `json:"email"` // email from apple
IsPrivateEmail bool `json:"is_private_email"`
EmailVerified bool `json:"email_verified"`
}
func (u *User) FillWithUserInfo(userInfo defs.UserGettable) {
u.UserID = userInfo.GetUserID()
func (u *User) FromUserInfo(userInfo defs.UserGettable) {
u.AppleUserID = userInfo.GetAppleUserID()
u.Name = userInfo.GetName()
u.Username = userInfo.GetUsername()
u.Email = userInfo.GetEmail()
@ -34,6 +38,28 @@ func (u *User) FillWithUserInfo(userInfo defs.UserGettable) {
u.EmailVerified = userInfo.GetEmailVerified()
}
func (u *User) FromJWTClaims(claims jwt.MapClaims) {
u.ID = cast.ToUint(claims["user_id"])
u.AppleUserID = cast.ToString(claims["apple_user_id"])
u.Name = cast.ToString(claims["name"])
u.Username = cast.ToString(claims["username"])
u.Email = cast.ToString(claims["email"])
u.IsPrivateEmail = cast.ToBool(claims["is_private_email"])
u.EmailVerified = cast.ToBool(claims["email_verified"])
}
func (u *User) ToMap() map[string]string {
return map[string]string{
"user_id": fmt.Sprint(u.ID),
"apple_user_id": u.AppleUserID,
"name": u.Name,
"username": u.Username,
"email": u.Email,
"is_private_email": fmt.Sprint(u.IsPrivateEmail),
"email_verified": fmt.Sprint(u.EmailVerified),
}
}
func (u *User) GetName() string {
return u.Name
}
@ -42,8 +68,12 @@ func (u *User) GetUsername() string {
return u.Username
}
func (u *User) GetUserID() string {
return u.UserID
func (u *User) GetAppleUserID() string {
return u.AppleUserID
}
func (u *User) GetUserID() int64 {
return int64(u.ID)
}
func (u *User) GetEmail() string {
@ -60,9 +90,10 @@ func (u *User) GetEmailVerified() bool {
func (u *User) ToUser() defs.User {
return defs.User{
UserID: u.UserID,
Name: u.Name,
Username: u.Username,
Email: u.Email,
ID: int64(u.ID),
AppleUserID: u.AppleUserID,
Name: u.Name,
Username: u.Username,
Email: u.Email,
}
}

View File

@ -1,6 +1,9 @@
package storage
import (
"context"
"github.com/nose7en/ToyBoomServer/common"
"gorm.io/gorm"
)
@ -24,7 +27,9 @@ type dbManagerImpl struct {
func (dbm *dbManagerImpl) Init(tables ...interface{}) {
dbs := dbm.DBs[DefaultDBName]
for _, db := range dbs {
db.AutoMigrate(tables...)
if err := db.AutoMigrate(tables...); err != nil {
common.Logger(context.Background()).WithError(err).Fatalf("auto migrate error!!!")
}
}
}