package handler import ( "errors" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/session" "github.com/uptrace/bun" "go.uber.org/zap" "omnibill.net/omnibill/models" "omnibill.net/omnibill/web/utils" ) func GetUserID(logger *zap.Logger, sess *session.Session, ctx *fiber.Ctx) (string, error) { userID := sess.Get("uid").(string) if userID == "" { return "", DestroySession(logger, sess, ctx, "/auth/login") } return userID, nil } func GetUser(logger *zap.Logger, sess *session.Session, db *bun.DB, ctx *fiber.Ctx) (*models.User, error) { userID, err := GetUserID(logger, sess, ctx) if err != nil { return nil, err } user := new(models.User) if err := db.NewSelect().Model(&user).Where("id = ?", userID).Scan(ctx.UserContext()); err != nil { logger.Error("error getting columns", zap.Error(err)) sessErr := DestroySession(logger, sess, ctx, "/auth/login") return nil, errors.Join(err, sessErr) } return user, nil } // Register registers the handler with the webserver. // It takes a handler as it's input. // If the handler does not implement any of the handler interfaces, it will panic. func Register(handler any) { var hasHandler bool if _, ok := handler.(utils.GET); ok { hasHandler = true } if _, ok := handler.(utils.POST); ok { hasHandler = true } if _, ok := handler.(utils.PUT); ok { hasHandler = true } if _, ok := handler.(utils.DELETE); ok { hasHandler = true } if _, ok := handler.(utils.PATCH); ok { hasHandler = true } if _, ok := handler.(utils.HEAD); ok { hasHandler = true } if _, ok := handler.(utils.OPTIONS); ok { hasHandler = true } if !hasHandler { panic("Invalid handler") } utils.Handlers = append(utils.Handlers, handler) } // DestroySession destroys any Session provided to it. // It takes: zap.Logger, session.Session, fiber.Ctx, and an optional redirectPath of type string. // The function returns a redirect with an error, or just an error, if any. func DestroySession(logger *zap.Logger, sess *session.Session, ctx *fiber.Ctx, redirectPath ...string) error { if err := sess.Destroy(); err != nil { logger.Error("error destroying session", zap.Error(err)) return fiber.ErrInternalServerError } if err := sess.Save(); err != nil { logger.Error("error saving session", zap.Error(err)) return fiber.ErrInternalServerError } if redirectPath != nil && len(redirectPath) > 0 { return ctx.Redirect(redirectPath[0]) } else { return nil } }