89 lines
2.4 KiB
Go
89 lines
2.4 KiB
Go
package handler
|
|
|
|
import (
|
|
"errors"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/gofiber/fiber/v2/middleware/session"
|
|
"github.com/uptrace/bun"
|
|
"go.uber.org/zap"
|
|
"omnibill.net/omnibill/models"
|
|
"omnibill.net/omnibill/web/utils"
|
|
)
|
|
|
|
func GetUserID(logger *zap.Logger, sess *session.Session, ctx *fiber.Ctx) (string, error) {
|
|
userID := sess.Get("uid").(string)
|
|
if userID == "" {
|
|
return "", DestroySession(logger, sess, ctx, "/auth/login")
|
|
}
|
|
return userID, nil
|
|
}
|
|
|
|
func GetUser(logger *zap.Logger, sess *session.Session, db *bun.DB, ctx *fiber.Ctx) (*models.User, error) {
|
|
userID, err := GetUserID(logger, sess, ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
user := new(models.User)
|
|
if err := db.NewSelect().Model(&user).Where("id = ?", userID).Scan(ctx.UserContext()); err != nil {
|
|
logger.Error("error getting columns", zap.Error(err))
|
|
sessErr := DestroySession(logger, sess, ctx, "/auth/login")
|
|
return nil, errors.Join(err, sessErr)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// Register registers the handler with the webserver.
|
|
// It takes a handler as it's input.
|
|
// If the handler does not implement any of the handler interfaces, it will panic.
|
|
func Register(handler any) {
|
|
var hasHandler bool
|
|
if _, ok := handler.(utils.GET); ok {
|
|
hasHandler = true
|
|
}
|
|
if _, ok := handler.(utils.POST); ok {
|
|
hasHandler = true
|
|
}
|
|
if _, ok := handler.(utils.PUT); ok {
|
|
hasHandler = true
|
|
}
|
|
if _, ok := handler.(utils.DELETE); ok {
|
|
hasHandler = true
|
|
}
|
|
if _, ok := handler.(utils.PATCH); ok {
|
|
hasHandler = true
|
|
}
|
|
if _, ok := handler.(utils.HEAD); ok {
|
|
hasHandler = true
|
|
}
|
|
if _, ok := handler.(utils.OPTIONS); ok {
|
|
hasHandler = true
|
|
}
|
|
|
|
if !hasHandler {
|
|
panic("Invalid handler")
|
|
}
|
|
|
|
utils.Handlers = append(utils.Handlers, handler)
|
|
}
|
|
|
|
// DestroySession destroys any Session provided to it.
|
|
// It takes: zap.Logger, session.Session, fiber.Ctx, and an optional redirectPath of type string.
|
|
// The function returns a redirect with an error, or just an error, if any.
|
|
func DestroySession(logger *zap.Logger, sess *session.Session, ctx *fiber.Ctx, redirectPath ...string) error {
|
|
if err := sess.Destroy(); err != nil {
|
|
logger.Error("error destroying session", zap.Error(err))
|
|
return fiber.ErrInternalServerError
|
|
}
|
|
if err := sess.Save(); err != nil {
|
|
logger.Error("error saving session", zap.Error(err))
|
|
return fiber.ErrInternalServerError
|
|
}
|
|
|
|
if redirectPath != nil && len(redirectPath) > 0 {
|
|
return ctx.Redirect(redirectPath[0])
|
|
} else {
|
|
return nil
|
|
}
|
|
}
|