51 lines
1.3 KiB
Go
51 lines
1.3 KiB
Go
|
package middleware
|
||
|
|
||
|
import (
|
||
|
"github.com/gofiber/fiber/v2"
|
||
|
"github.com/gofiber/fiber/v2/middleware/session"
|
||
|
"github.com/uptrace/bun"
|
||
|
"go.uber.org/zap"
|
||
|
"omnibill.net/omnibill/models"
|
||
|
"reflect"
|
||
|
)
|
||
|
|
||
|
func Auth(logger *zap.Logger, db *bun.DB, authSessionStore *session.Store, handler interface{}) fiber.Handler {
|
||
|
return func(c *fiber.Ctx) error {
|
||
|
if !c.IsProxyTrusted() {
|
||
|
return fiber.ErrUnauthorized
|
||
|
}
|
||
|
|
||
|
authSession, err := authSessionStore.Get(c)
|
||
|
if err != nil {
|
||
|
return fiber.ErrUnauthorized
|
||
|
}
|
||
|
|
||
|
if len(authSession.Keys()) == 0 {
|
||
|
return fiber.ErrUnauthorized
|
||
|
}
|
||
|
|
||
|
var user models.User
|
||
|
userID := authSession.Get("uid").(string)
|
||
|
keyCount, err := db.NewSelect().Model(&user).Where("id = ?", userID).Count(c.UserContext())
|
||
|
if err != nil {
|
||
|
logger.Error("error getting columns", zap.Error(err))
|
||
|
return fiber.ErrInternalServerError
|
||
|
}
|
||
|
|
||
|
if keyCount == 0 {
|
||
|
if err := authSession.Destroy(); err != nil {
|
||
|
logger.Error("error destroying session", zap.Error(err))
|
||
|
return fiber.ErrInternalServerError
|
||
|
}
|
||
|
if err := authSession.Save(); err != nil {
|
||
|
logger.Error("error saving session", zap.Error(err))
|
||
|
return fiber.ErrInternalServerError
|
||
|
}
|
||
|
return fiber.ErrUnauthorized
|
||
|
}
|
||
|
|
||
|
reflect.ValueOf(handler).Elem().FieldByName("AuthSession").Set(reflect.ValueOf(authSession))
|
||
|
return nil
|
||
|
}
|
||
|
}
|