package handlers

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"os"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v4"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
	"golang.org/x/crypto/bcrypt"
)

// verifyJWT parses tokenString as an HMAC-signed JWT carrying usernameClaims,
// keyed with the JWTSECRET environment variable, and returns the Name claim
// when the jwt library reports the token as valid.
//
// An empty JWTSECRET is rejected outright: verifying with an empty HMAC key
// would let anyone forge tokens.
func verifyJWT(tokenString string) (string, error) {
	jwtsecret := []byte(os.Getenv("JWTSECRET"))
	if len(jwtsecret) == 0 {
		return "", errors.New("jwt secret is not configured")
	}

	token, err := jwt.ParseWithClaims(tokenString, &usernameClaims{}, func(token *jwt.Token) (interface{}, error) {
		// Accept only HMAC algorithms so a token cannot pick a different
		// algorithm family than the shared-secret scheme used to sign it.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return jwtsecret, nil
	})
	if err != nil {
		return "", err
	}

	claims, ok := token.Claims.(*usernameClaims)
	if !ok || !token.Valid {
		return "", errors.New("invalid token")
	}
	return claims.Name, nil
}

// GetToken returns a gin handler that exchanges a name/password pair for a
// signed JWT.
//
// The JSON body is bound into tokenAuthReq, and the user is looked up by
// "name" in the "users" collection of the DB_NAME database. The stored
// "password" field must be a bcrypt hash that matches the supplied password.
// On success the reply is an accesstoken holding an HS256 token signed with
// JWTSECRET that expires one week after issue.
//
// Status codes:
//   - 400: the body cannot be bound.
//   - 404: no such user.
//   - 401: the stored hash is missing or the password is wrong.
//   - 500: a database or signing failure, including an unset JWTSECRET.
//   - 200: success.
//
// The submitted password is never echoed back in the "payload" field.
func GetToken(mongoc *mongo.Client) gin.HandlerFunc {
	const tokenTTL = 7 * 24 * time.Hour

	return func(ctx *gin.Context) {
		var tokenauth tokenAuthReq
		bindErr := ctx.BindJSON(&tokenauth)
		// Copy of the request that is safe to return to the client;
		// responses can end up in logs and proxies.
		payload := tokenauth
		payload.Password = ""
		if bindErr != nil {
			ctx.JSON(http.StatusBadRequest, gin.H{
				"message": bindErr.Error(),
				"success": false,
				"payload": payload,
			})
			return
		}

		coll := mongoc.Database(os.Getenv("DB_NAME")).Collection("users")
		// Derive from the request so a disconnected client cancels the query
		// instead of leaving it running until the timeout.
		mctx, cancel := context.WithTimeout(ctx.Request.Context(), 5*time.Second)
		defer cancel()

		var userdoc bson.M
		if err := coll.FindOne(mctx, bson.D{{Key: "name", Value: tokenauth.Name}}).Decode(&userdoc); err != nil {
			fmt.Println(err)
			// Only a missing document means "no such user"; anything else is
			// a server-side failure.
			status := http.StatusInternalServerError
			if errors.Is(err, mongo.ErrNoDocuments) {
				status = http.StatusNotFound
			}
			ctx.JSON(status, gin.H{
				"message": err.Error(),
				"success": false,
				"payload": payload,
			})
			return
		}

		hashedPass, ok := userdoc["password"].(string)
		if !ok || bcrypt.CompareHashAndPassword([]byte(hashedPass), []byte(tokenauth.Password)) != nil {
			ctx.JSON(http.StatusUnauthorized, gin.H{
				"message": "Unauthorized",
				"success": false,
			})
			return
		}

		// Never issue tokens signed with an empty key: they would be trivially
		// forgeable.
		jwtsecret := []byte(os.Getenv("JWTSECRET"))
		if len(jwtsecret) == 0 {
			ctx.JSON(http.StatusInternalServerError, gin.H{
				"message": "JWT secret is not configured",
				"success": false,
			})
			return
		}

		claims := usernameClaims{
			tokenauth.Name,
			jwt.RegisteredClaims{
				ExpiresAt: jwt.NewNumericDate(time.Now().Add(tokenTTL)),
			},
		}
		tokenstring, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(jwtsecret)
		if err != nil {
			ctx.JSON(http.StatusInternalServerError, gin.H{
				"message": err.Error(),
				"success": false,
			})
			return
		}

		ctx.JSON(http.StatusOK, accesstoken{
			Message: "Token generated",
			Success: true,
			Token:   tokenstring,
		})
	}
}

// RegisterHandler returns a gin handler that creates a user account and
// returns a JWT for it.
//
// The JSON body is bound into registerReq. The request is rejected with 400
// in three cases:
//   - the body cannot be bound;
//   - Regcode differs from the REGCODE environment variable;
//   - a user with the same name already exists in the "users" collection of
//     the DB_NAME database.
//
// Otherwise the password is stored as a bcrypt hash. The reply is an
// accesstoken holding an HS256 token signed with JWTSECRET that expires one
// week after issue. Database, hashing and signing failures (including an
// unset JWTSECRET) yield 500.
//
// The submitted password and regcode are never echoed back in "payload".
func RegisterHandler(mongoc *mongo.Client) gin.HandlerFunc {
	const tokenTTL = 7 * 24 * time.Hour

	return func(ctx *gin.Context) {
		var regreq registerReq
		bindErr := ctx.BindJSON(&regreq)
		// Copy of the request that is safe to return to the client;
		// responses can end up in logs and proxies.
		payload := regreq
		payload.Password = ""
		payload.Regcode = ""
		if bindErr != nil {
			ctx.JSON(http.StatusBadRequest, gin.H{
				"message": bindErr.Error(),
				"success": false,
				"payload": payload,
			})
			return
		}

		if regreq.Regcode != os.Getenv("REGCODE") {
			ctx.JSON(http.StatusBadRequest, gin.H{
				"message": "Invalid regcode",
				"success": false,
			})
			return
		}

		// Check the signing key before writing anything, so a misconfigured
		// server does not create accounts it cannot issue tokens for.
		jwtsecret := []byte(os.Getenv("JWTSECRET"))
		if len(jwtsecret) == 0 {
			ctx.JSON(http.StatusInternalServerError, gin.H{
				"message": "JWT secret is not configured",
				"success": false,
			})
			return
		}

		coll := mongoc.Database(os.Getenv("DB_NAME")).Collection("users")
		// Derive from the request so a disconnected client cancels the
		// database work instead of leaving it running until the timeout.
		mctx, cancel := context.WithTimeout(ctx.Request.Context(), 5*time.Second)
		defer cancel()

		var existing bson.M
		err := coll.FindOne(mctx, bson.D{{Key: "name", Value: regreq.Username}}).Decode(&existing)
		switch {
		case err == nil:
			ctx.JSON(http.StatusBadRequest, gin.H{
				"message": "User already exists",
				"success": false,
				"payload": payload,
			})
			return
		case !errors.Is(err, mongo.ErrNoDocuments):
			fmt.Println(err)
			ctx.JSON(http.StatusInternalServerError, gin.H{
				"message": err.Error(),
				"success": false,
				"payload": payload,
			})
			return
		}

		hashedPass, err := bcrypt.GenerateFromPassword([]byte(regreq.Password), bcrypt.DefaultCost)
		if err != nil {
			fmt.Println(err)
			ctx.JSON(http.StatusInternalServerError, gin.H{
				// Use the text: an error value usually marshals to "{}".
				"message": err.Error(),
				"success": false,
				"payload": payload,
			})
			return
		}

		reguser := userAccount{
			Name:     regreq.Username,
			Email:    regreq.Email,
			Corp:     regreq.Username,
			Password: string(hashedPass),
		}

		if _, err := coll.InsertOne(mctx, reguser); err != nil {
			// Another request may have registered the same name between the
			// FindOne above and this insert. That is only detectable when
			// "name" has a unique index.
			// NOTE(review): confirm that index exists.
			if mongo.IsDuplicateKeyError(err) {
				ctx.JSON(http.StatusBadRequest, gin.H{
					"message": "User already exists",
					"success": false,
					"payload": payload,
				})
				return
			}
			ctx.JSON(http.StatusInternalServerError, gin.H{
				"message": "Failed to add new user:" + err.Error(),
				"success": false,
				"payload": payload,
			})
			return
		}

		claims := usernameClaims{
			regreq.Username,
			jwt.RegisteredClaims{
				ExpiresAt: jwt.NewNumericDate(time.Now().Add(tokenTTL)),
			},
		}
		tokenstring, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(jwtsecret)
		if err != nil {
			ctx.JSON(http.StatusInternalServerError, gin.H{
				"message": err.Error(),
				"success": false,
			})
			return
		}

		ctx.JSON(http.StatusOK, accesstoken{
			Message: "User registered",
			Success: true,
			Token:   tokenstring,
		})
	}
}

// devUserAuth returns the device's _id and its "user" ObjectID. Both
// conditions must hold:
//   - tokenString verifies as a JWT;
//   - the device whose "deviceid" equals device belongs to the user named in
//     that token.
//
// The device is read from the "devices" collection of the DB_NAME database.
// Its owner is joined from "users" by matching the device's "user" field
// against users' "_id". On any failure both IDs are primitive.NilObjectID
// and the error describes the cause.
func devUserAuth(mongoc *mongo.Client, tokenString string, device string) (primitive.ObjectID, primitive.ObjectID, error) {
	nilID := primitive.NilObjectID

	tokenuser, err := verifyJWT(tokenString)
	if err != nil {
		return nilID, nilID, err
	}

	colldevice := mongoc.Database(os.Getenv("DB_NAME")).Collection("devices")
	mctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	pipeline := mongo.Pipeline{
		{{Key: "$match", Value: bson.M{"deviceid": device}}},
		// Only the first matching device is inspected, so stop before the join.
		{{Key: "$limit", Value: 1}},
		{{Key: "$lookup", Value: bson.M{
			"from":         "users",
			"localField":   "user",
			"foreignField": "_id",
			"as":           "userdoc",
		}}},
	}

	cursor, err := colldevice.Aggregate(mctx, pipeline)
	if err != nil {
		return nilID, nilID, err
	}
	// Release the server-side cursor once we are done with it.
	defer cursor.Close(mctx)

	if !cursor.Next(mctx) {
		// Next returns false both at end-of-results and on failure.
		if err := cursor.Err(); err != nil {
			return nilID, nilID, err
		}
		return nilID, nilID, errors.New("device not found")
	}
	var devicedoc bson.M
	if err := cursor.Decode(&devicedoc); err != nil {
		return nilID, nilID, err
	}

	// Check that the device belongs to the JWT-authenticated user.
	owners, ok := devicedoc["userdoc"].(primitive.A)
	if !ok {
		return nilID, nilID, errors.New("no userdoc in devicedoc")
	}
	// The join is empty when no user has the device's "user" id; guard the
	// index below, which would otherwise panic.
	if len(owners) == 0 {
		return nilID, nilID, errors.New("device has no owner")
	}
	owner, ok := owners[0].(primitive.M)
	if !ok {
		return nilID, nilID, errors.New("userdoc is not a document")
	}
	ownerName, ok := owner["name"].(string)
	if !ok {
		return nilID, nilID, errors.New("no name in userdoc")
	}
	if ownerName != tokenuser {
		return nilID, nilID, errors.New("this is not your device")
	}

	devid, ok := devicedoc["_id"].(primitive.ObjectID)
	if !ok {
		return nilID, nilID, errors.New("no _id in devicedoc")
	}
	userid, ok := devicedoc["user"].(primitive.ObjectID)
	if !ok {
		return nilID, nilID, errors.New("no userid in devicedoc")
	}
	return devid, userid, nil
}
