feat: login with apple
This commit is contained in:
parent
13148b95e3
commit
0948d23239
@ -17,13 +17,13 @@ func Router() *gin.Engine {
|
||||
|
||||
userRouter := v1.Group("/user")
|
||||
{
|
||||
userRouter.POST("/create", middleware.ValidateAppleAppToken(), common.Wrapper(user.CreateUser))
|
||||
userRouter.GET("/info", middleware.ValidateAppleAppToken(), common.Wrapper(user.GetUserInfo))
|
||||
userRouter.POST("/login-with-apple", middleware.ValidateAppleAppLoginCode(), common.Wrapper(user.LoginWithApple))
|
||||
userRouter.GET("/info", middleware.ValidateToken(), common.Wrapper(user.GetUserInfo))
|
||||
}
|
||||
|
||||
if config.IsDebug() {
|
||||
// for debug
|
||||
v1.GET("/ping", middleware.ValidateAppleAppToken(), func(ctx *gin.Context) { ctx.JSON(200, gin.H{"message": "pong"}) })
|
||||
v1.GET("/ping", middleware.ValidateToken(), func(ctx *gin.Context) { ctx.JSON(200, gin.H{"message": "pong"}) })
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
@ -2,23 +2,67 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/nose7en/ToyBoomServer/common"
|
||||
"github.com/nose7en/ToyBoomServer/config"
|
||||
"github.com/nose7en/ToyBoomServer/dao"
|
||||
"github.com/nose7en/ToyBoomServer/defs"
|
||||
"github.com/nose7en/ToyBoomServer/models"
|
||||
"github.com/nose7en/ToyBoomServer/utils"
|
||||
)
|
||||
|
||||
func CreateUser(c context.Context, req *defs.CommonRequest) (*defs.CommonResponse, error) {
|
||||
func LoginWithApple(c context.Context, req *defs.CommonRequest) (*defs.GetUserAuthTokenResponse, error) {
|
||||
userInfo := common.GetUser(c)
|
||||
newUser := &models.User{}
|
||||
newUser.FillWithUserInfo(userInfo)
|
||||
|
||||
if err := dao.NewMutation().CreateUser(newUser); err != nil {
|
||||
if userInfo.GetUserID() > 0 {
|
||||
userInfo, err := dao.NewQuery().GetUserByID(userInfo.GetUserID())
|
||||
if err != nil {
|
||||
common.Logger(c).WithError(err).Errorf("failed to get user info")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &defs.CommonResponse{
|
||||
newUserToken, err := newUserToken(userInfo)
|
||||
if err != nil {
|
||||
common.Logger(c).WithError(err).Errorf("failed to new user token")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &defs.GetUserAuthTokenResponse{
|
||||
Status: &defs.Status{Code: defs.RespCode_SUCCESS, Message: "user exists"},
|
||||
Token: newUserToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
newUser, err := dao.NewMutation().FirstOrCreateUser(userInfo)
|
||||
if err != nil {
|
||||
common.Logger(c).WithError(err).Errorf("failed to create user")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newUserToken, err := newUserToken(newUser)
|
||||
if err != nil {
|
||||
common.Logger(c).WithError(err).Errorf("failed to new user token")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
common.Logger(c).Infof("create user success, user info record is %+v", newUser)
|
||||
return &defs.GetUserAuthTokenResponse{
|
||||
Status: &defs.Status{Code: defs.RespCode_SUCCESS, Message: defs.RespMessage_SUCCESS},
|
||||
Token: newUserToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newUserToken(userInfo defs.UserGettable) (string, error) {
|
||||
token, err := utils.GetJwtTokenFromMap(
|
||||
config.GetSettings().JWTConfig.Secret,
|
||||
time.Now().Unix(),
|
||||
config.GetSettings().JWTConfig.ExpireSec,
|
||||
userInfo.ToMap(),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/nose7en/ToyBoomServer/defs"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func GetUser(c context.Context) defs.UserGettable {
|
||||
@ -19,3 +20,17 @@ func GetUser(c context.Context) defs.UserGettable {
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func GetToken(c context.Context) string {
|
||||
return GetStrValue(c, TokenKey)
|
||||
}
|
||||
|
||||
func GetStrValue(c context.Context, key string) string {
|
||||
val := c.Value(key)
|
||||
return cast.ToString(val)
|
||||
}
|
||||
|
||||
func GetStrValueE(c context.Context, key string) (string, error) {
|
||||
val := c.Value(key)
|
||||
return cast.ToStringE(val)
|
||||
}
|
||||
|
@ -34,6 +34,7 @@ type RedisConf struct {
|
||||
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
ExpireSec int64 `mapstructure:"expire_sec"`
|
||||
}
|
||||
|
||||
type AppleConf struct {
|
||||
@ -62,6 +63,7 @@ func fillDefaultSettings() {
|
||||
viper.SetDefault("debug", false)
|
||||
viper.SetDefault("db.type", "sqlite")
|
||||
viper.SetDefault("db.dsn", "toyboom.db")
|
||||
viper.SetDefault("jwt.expire_sec", 86400*30) // 30 days
|
||||
}
|
||||
|
||||
func setConfigParams() {
|
||||
|
@ -2,26 +2,30 @@ package dao
|
||||
|
||||
import (
|
||||
"github.com/nose7en/ToyBoomServer/defs"
|
||||
"github.com/nose7en/ToyBoomServer/storage"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Query interface {
|
||||
GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error)
|
||||
GetUserByID(userID int64) (defs.UserGettable, error)
|
||||
}
|
||||
|
||||
type Mutation interface {
|
||||
CreateUser(user defs.UserGettable) error
|
||||
FirstOrCreateUser(user defs.UserGettable) (defs.UserGettable, error)
|
||||
}
|
||||
|
||||
var _ Query = (*queryImpl)(nil)
|
||||
var _ Mutation = (*mutationImpl)(nil)
|
||||
|
||||
type queryImpl struct{}
|
||||
type mutationImpl struct{}
|
||||
type queryImpl struct{ db *gorm.DB }
|
||||
|
||||
type mutationImpl struct{ db *gorm.DB }
|
||||
|
||||
func NewQuery() Query {
|
||||
return &queryImpl{}
|
||||
return &queryImpl{db: storage.GetDBManager().GetDefaultDB()}
|
||||
}
|
||||
|
||||
func NewMutation() Mutation {
|
||||
return &mutationImpl{}
|
||||
return &mutationImpl{db: storage.GetDBManager().GetDefaultDB()}
|
||||
}
|
||||
|
32
dao/user.go
32
dao/user.go
@ -1,13 +1,33 @@
|
||||
package dao
|
||||
|
||||
import "github.com/nose7en/ToyBoomServer/defs"
|
||||
import (
|
||||
"github.com/nose7en/ToyBoomServer/defs"
|
||||
"github.com/nose7en/ToyBoomServer/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (q *queryImpl) GetUserByAppleUserID(appleUserID string) (defs.UserGettable, error) {
|
||||
|
||||
return nil, nil
|
||||
user := &models.User{}
|
||||
if result := q.db.Where(models.User{AppleUserID: appleUserID}).First(user); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (m *mutationImpl) CreateUser(user defs.UserGettable) error {
|
||||
|
||||
return nil
|
||||
func (q *queryImpl) GetUserByID(userID int64) (defs.UserGettable, error) {
|
||||
user := &models.User{}
|
||||
if result := q.db.Where(models.User{Model: gorm.Model{ID: uint(userID)}}).First(user); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (m *mutationImpl) FirstOrCreateUser(user defs.UserGettable) (defs.UserGettable, error) {
|
||||
record := &models.User{}
|
||||
newUserAttrs := &models.User{}
|
||||
newUserAttrs.FromUserInfo(user)
|
||||
result := m.db.Where(models.User{AppleUserID: user.GetAppleUserID()}).
|
||||
Attrs(newUserAttrs).
|
||||
FirstOrCreate(record)
|
||||
return record, result.Error
|
||||
}
|
||||
|
@ -1,7 +1,8 @@
|
||||
package defs
|
||||
|
||||
type User struct {
|
||||
UserID string `json:"user_id"`
|
||||
ID int64 `json:"user_id"`
|
||||
AppleUserID string `json:"apple_user_id"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
|
@ -1,11 +1,16 @@
|
||||
package defs
|
||||
|
||||
import "github.com/golang-jwt/jwt/v5"
|
||||
|
||||
type UserGettable interface {
|
||||
GetUserID() string
|
||||
GetUserID() int64
|
||||
GetAppleUserID() string
|
||||
GetName() string
|
||||
GetUsername() string
|
||||
GetEmail() string
|
||||
GetIsPrivateEmail() bool
|
||||
GetEmailVerified() bool
|
||||
FromJWTClaims(jwt.MapClaims)
|
||||
ToUser() User
|
||||
ToMap() map[string]string
|
||||
}
|
||||
|
2
go.mod
2
go.mod
@ -63,7 +63,7 @@ require (
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/spf13/afero v1.11.0 // indirect
|
||||
github.com/spf13/cast v1.6.0 // indirect
|
||||
github.com/spf13/cast v1.6.0
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
|
@ -1,19 +1,23 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Timothylock/go-signin-with-apple/apple"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/nose7en/ToyBoomServer/common"
|
||||
"github.com/nose7en/ToyBoomServer/config"
|
||||
"github.com/nose7en/ToyBoomServer/defs"
|
||||
"github.com/nose7en/ToyBoomServer/models"
|
||||
"github.com/nose7en/ToyBoomServer/rpc"
|
||||
"github.com/nose7en/ToyBoomServer/utils"
|
||||
"github.com/spf13/cast"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ValidateAppleAppToken() func(c *gin.Context) {
|
||||
func ValidateAppleAppLoginCode() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
code := c.GetHeader(common.TokenKey)
|
||||
resp, err := rpc.GetManager().AppleCli().VerifyAppToken(c, code)
|
||||
@ -47,8 +51,8 @@ func ValidateAppleAppToken() func(c *gin.Context) {
|
||||
emailVerified := cast.ToBool((*claim)["email_verified"])
|
||||
isPrivateEmail := cast.ToBool((*claim)["is_private_email"])
|
||||
|
||||
userInfo := &defs.User{
|
||||
UserID: unique,
|
||||
userInfo := &models.User{
|
||||
AppleUserID: unique,
|
||||
Email: email,
|
||||
IsPrivateEmail: isPrivateEmail,
|
||||
EmailVerified: emailVerified,
|
||||
@ -57,3 +61,74 @@ func ValidateAppleAppToken() func(c *gin.Context) {
|
||||
c.Set(common.UserInfoKey, userInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateToken() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
tokens, err := GetAuthTokensFromAny(c)
|
||||
if err != nil {
|
||||
common.Logger(c).WithError(err).Error("failed to get auth token")
|
||||
c.AbortWithStatusJSON(http.StatusOK, common.UnAuth("failed to get auth token"))
|
||||
return
|
||||
}
|
||||
|
||||
var clms jwt.MapClaims
|
||||
for tokenType, tokenStr := range tokens {
|
||||
clms, err = validateToken(c, tokenStr)
|
||||
if err == nil {
|
||||
common.Logger(c).Infof("jwt middleware parse token success, token type: %s", tokenType)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.Logger(c).WithError(err).Errorf("jwt middleware parse token error")
|
||||
c.JSON(http.StatusOK, common.UnAuth("invalid authorization"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userInfo := &models.User{}
|
||||
userInfo.FromJWTClaims(clms)
|
||||
if userInfo.GetUserID() <= 0 {
|
||||
common.Logger(c).Errorf("failed to build user info from token")
|
||||
c.AbortWithStatusJSON(http.StatusOK, common.UnAuth("invalid authorization"))
|
||||
return
|
||||
}
|
||||
|
||||
common.Logger(c).Infof("token auth success, user info: %+v", userInfo)
|
||||
c.Set(common.UserInfoKey, userInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func validateToken(ctx *gin.Context, tokenStr string) (u jwt.MapClaims, err error) {
|
||||
if tokenStr == "" {
|
||||
return nil, errors.New("token_str is empty")
|
||||
}
|
||||
|
||||
if t, err := utils.ParseToken(config.GetSettings().JWTConfig.Secret, tokenStr); err == nil {
|
||||
for k, v := range t {
|
||||
ctx.Set(k, v)
|
||||
}
|
||||
ctx.Set(common.TokenKey, tokenStr)
|
||||
return t, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func GetAuthTokensFromAny(c *gin.Context) (map[string]string, error) {
|
||||
ans := map[string]string{}
|
||||
if headerTokenStr := c.Request.Header.Get(common.AuthorizationKey); len(headerTokenStr) > 0 {
|
||||
headerToken := strings.Split(headerTokenStr, " ")
|
||||
if len(headerToken) == 2 {
|
||||
ans[common.TokenKey] = headerToken[1]
|
||||
} else {
|
||||
ans[common.TokenKey] = headerTokenStr
|
||||
}
|
||||
}
|
||||
|
||||
if len(ans) == 0 {
|
||||
return nil, errors.New("auth token is empty")
|
||||
}
|
||||
|
||||
return ans, nil
|
||||
}
|
||||
|
25
middleware/trace.go
Normal file
25
middleware/trace.go
Normal file
@ -0,0 +1,25 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/nose7en/ToyBoomServer/common"
|
||||
"github.com/nose7en/ToyBoomServer/utils"
|
||||
)
|
||||
|
||||
func Trace() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
UpStreamTraceID := c.GetHeader(common.TraceIDKey)
|
||||
if len(UpStreamTraceID) > 0 {
|
||||
c.Set(common.TraceIDKey, UpStreamTraceID)
|
||||
common.Logger(c).Infof("get a common request, upstream traceid %v",
|
||||
c.GetString(common.TraceIDKey))
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(common.TraceIDKey, utils.GenerateUID())
|
||||
common.Logger(c).Infof("get a common request, gen traceid %v",
|
||||
c.GetString(common.TraceIDKey))
|
||||
c.Next()
|
||||
}
|
||||
}
|
@ -1,7 +1,11 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/nose7en/ToyBoomServer/defs"
|
||||
"github.com/spf13/cast"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@ -17,7 +21,7 @@ func NewUserGettable(opt ...func(*User)) defs.UserGettable {
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
UserID string `json:"user_id"` // user id from apple
|
||||
AppleUserID string `json:"apple_user_id" gorm:"uniqueIndex"` // user id from apple
|
||||
Name string `json:"name"` // ToyBoom's user name
|
||||
Username string `json:"username"` // user name from apple
|
||||
Email string `json:"email"` // email from apple
|
||||
@ -25,8 +29,8 @@ type User struct {
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
}
|
||||
|
||||
func (u *User) FillWithUserInfo(userInfo defs.UserGettable) {
|
||||
u.UserID = userInfo.GetUserID()
|
||||
func (u *User) FromUserInfo(userInfo defs.UserGettable) {
|
||||
u.AppleUserID = userInfo.GetAppleUserID()
|
||||
u.Name = userInfo.GetName()
|
||||
u.Username = userInfo.GetUsername()
|
||||
u.Email = userInfo.GetEmail()
|
||||
@ -34,6 +38,28 @@ func (u *User) FillWithUserInfo(userInfo defs.UserGettable) {
|
||||
u.EmailVerified = userInfo.GetEmailVerified()
|
||||
}
|
||||
|
||||
func (u *User) FromJWTClaims(claims jwt.MapClaims) {
|
||||
u.ID = cast.ToUint(claims["user_id"])
|
||||
u.AppleUserID = cast.ToString(claims["apple_user_id"])
|
||||
u.Name = cast.ToString(claims["name"])
|
||||
u.Username = cast.ToString(claims["username"])
|
||||
u.Email = cast.ToString(claims["email"])
|
||||
u.IsPrivateEmail = cast.ToBool(claims["is_private_email"])
|
||||
u.EmailVerified = cast.ToBool(claims["email_verified"])
|
||||
}
|
||||
|
||||
func (u *User) ToMap() map[string]string {
|
||||
return map[string]string{
|
||||
"user_id": fmt.Sprint(u.ID),
|
||||
"apple_user_id": u.AppleUserID,
|
||||
"name": u.Name,
|
||||
"username": u.Username,
|
||||
"email": u.Email,
|
||||
"is_private_email": fmt.Sprint(u.IsPrivateEmail),
|
||||
"email_verified": fmt.Sprint(u.EmailVerified),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) GetName() string {
|
||||
return u.Name
|
||||
}
|
||||
@ -42,8 +68,12 @@ func (u *User) GetUsername() string {
|
||||
return u.Username
|
||||
}
|
||||
|
||||
func (u *User) GetUserID() string {
|
||||
return u.UserID
|
||||
func (u *User) GetAppleUserID() string {
|
||||
return u.AppleUserID
|
||||
}
|
||||
|
||||
func (u *User) GetUserID() int64 {
|
||||
return int64(u.ID)
|
||||
}
|
||||
|
||||
func (u *User) GetEmail() string {
|
||||
@ -60,7 +90,8 @@ func (u *User) GetEmailVerified() bool {
|
||||
|
||||
func (u *User) ToUser() defs.User {
|
||||
return defs.User{
|
||||
UserID: u.UserID,
|
||||
ID: int64(u.ID),
|
||||
AppleUserID: u.AppleUserID,
|
||||
Name: u.Name,
|
||||
Username: u.Username,
|
||||
Email: u.Email,
|
||||
|
@ -1,6 +1,9 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/nose7en/ToyBoomServer/common"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@ -24,7 +27,9 @@ type dbManagerImpl struct {
|
||||
func (dbm *dbManagerImpl) Init(tables ...interface{}) {
|
||||
dbs := dbm.DBs[DefaultDBName]
|
||||
for _, db := range dbs {
|
||||
db.AutoMigrate(tables...)
|
||||
if err := db.AutoMigrate(tables...); err != nil {
|
||||
common.Logger(context.Background()).WithError(err).Fatalf("auto migrate error!!!")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user