package main

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"os/user"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/creack/pty"
	"github.com/gorilla/websocket"
	"github.com/gpuaas/platform/packages/shared/buildinfo"
)

// terminalWSControlFrame is the JSON shape of websocket text messages on the
// terminal stream; binary messages carry raw PTY bytes instead. Within this
// file the node sends "close" frames (CloseReason, plus ErrorCode,
// ErrorMessage and RuntimeUsername when opening fails) and acts on inbound
// "resize" (Cols/Rows) and "close" (CloseReason) frames. Fields tagged
// omitempty are left out of the encoding when empty or nil.
type terminalWSControlFrame struct {
	Type            string `json:"type"`
	CloseReason     string `json:"close_reason,omitempty"`
	ErrorCode       string `json:"error_code,omitempty"`
	ErrorMessage    string `json:"error_message,omitempty"`
	RuntimeUsername string `json:"runtime_username,omitempty"`
	Cols            *int   `json:"cols,omitempty"`
	Rows            *int   `json:"rows,omitempty"`
}

// terminalSession is the registry entry (stored in terminalSessions) for one
// running terminal session.
//
// in is the open request that started the session. cancel cancels the
// session's context. done is closed when the session goroutine returns.
type terminalSession struct {
	in     terminalOpenInput
	cancel context.CancelFunc
	done   chan struct{}
}

// terminalSessionOutcome describes how a terminal session ended.
//
// protocolReason is the close reason reported to the gateway (normalized via
// normalizeTerminalProtocolReason before sending). initiator names the side
// that ended the session (values used here: "node_agent", "pty", "gateway",
// "node_agent_context"). err is the failure that ended the session, or nil for
// a clean close.
type terminalSessionOutcome struct {
	protocolReason string
	initiator      string
	err            error
}

// terminalStartupResult is the startup report a session goroutine sends back
// to runTerminalOpenSession: err is nil once the PTY has started, otherwise it
// is the reason the session failed to open.
type terminalStartupResult struct {
	err error
}

var (
	// terminalSessions maps session_id => *terminalSession for registered sessions.
	terminalSessions sync.Map
	// terminalAllocSessions maps allocation_id => session_id.
	terminalAllocSessions sync.Map
	// terminalSessionsStarted counts calls to runTerminalOpenSession.
	terminalSessionsStarted atomic.Uint64

	// Indirections over os/user and os; presumably swappable so tests can fake
	// user lookup and the effective uid (TODO confirm).
	lookupTerminalUser = user.Lookup
	terminalEUID       = os.Geteuid
)

// Largest values representable by uint16 / uint32, expressed as int for range
// checks before narrowing conversions.
const (
	maxUint16Int = int(^uint16(0))
	maxUint32Int = int(^uint32(0))
)

// runTerminalOpenSession launches the terminal session described by in on a
// background goroutine and waits for it to report startup.
//
// A session already registered under in.SessionID is cancelled first (this
// call does not wait for it to finish). The new session stays registered in
// terminalSessions (by session ID) and terminalAllocSessions (by allocation
// ID) for as long as it runs, so closeTerminalSessionByAllocation can find and
// cancel it.
//
// It returns nil once the session reports a successful start. If opening
// fails, it returns the startup error after the session goroutine has fully
// finished. If no report arrives within terminalOpenStartupTimeout, it returns
// a timeout error; the session is cancelled but not waited for.
func runTerminalOpenSession(in terminalOpenInput) error {
	slog.Info("terminal session start requested", "session_id", in.SessionID, "allocation_id", in.AllocationID, "username", in.Username)
	terminalSessionsStarted.Add(1)
	if existing, ok := terminalSessions.Load(in.SessionID); ok {
		if prev, cast := existing.(*terminalSession); cast {
			prev.cancel()
		}
	}
	ctx, cancel := context.WithCancel(context.Background())
	sess := &terminalSession{
		in:     in,
		cancel: cancel,
		done:   make(chan struct{}),
	}
	terminalSessions.Store(in.SessionID, sess)
	terminalAllocSessions.Store(in.AllocationID, in.SessionID)
	startup := make(chan terminalStartupResult, 1)
	go func() {
		// Deferred calls run bottom-up: unregister, release ctx, then signal
		// done, so anyone waiting on done sees a fully cleaned-up session.
		defer close(sess.done)
		// Cancelling on every exit path releases ctx and anything watching it
		// (e.g. exec.CommandContext) even when the session ends on its own.
		defer cancel()
		defer func() {
			// Only remove entries that still belong to this session: a newer
			// session registered under the same ID owns them otherwise. The
			// allocation mapping must live until the session ends so that
			// closeTerminalSessionByAllocation can reach it.
			if terminalSessions.CompareAndDelete(in.SessionID, sess) {
				terminalAllocSessions.CompareAndDelete(in.AllocationID, in.SessionID)
			}
		}()
		startedAt := time.Now()
		outcome := streamPTYSession(ctx, in, startup)
		duration := time.Since(startedAt)
		slog.Info("terminal session closed",
			"session_id", in.SessionID,
			"allocation_id", in.AllocationID,
			"username", in.Username,
			"close_reason", strings.TrimSpace(outcome.protocolReason),
			"initiator", strings.TrimSpace(outcome.initiator),
			"duration_ms", duration.Milliseconds(),
		)
		if outcome.err != nil {
			slog.Warn("terminal session ended with error",
				"session_id", in.SessionID,
				"allocation_id", in.AllocationID,
				"close_reason", strings.TrimSpace(outcome.protocolReason),
				"initiator", strings.TrimSpace(outcome.initiator),
				"error", outcome.err,
			)
		}
	}()

	timer := time.NewTimer(terminalOpenStartupTimeout())
	defer timer.Stop()
	select {
	case result := <-startup:
		if result.err != nil {
			cancel()
			<-sess.done
			return result.err
		}
		return nil
	case <-timer.C:
		cancel()
		return fmt.Errorf("terminal startup timed out before node stream was ready")
	}
}

// terminalSessionsStartedTotal returns the terminalSessionsStarted counter.
// runTerminalOpenSession increments it on every call, before it knows whether
// the open will succeed.
func terminalSessionsStartedTotal() uint64 {
	total := terminalSessionsStarted.Load()
	return total
}

// terminalSessionsActiveTotal reports how many sessions are currently
// registered in terminalSessions. It walks the whole map, so the cost is
// linear in the number of sessions. Because sync.Map.Range is not an atomic
// snapshot, concurrent registrations or removals may or may not be counted.
func terminalSessionsActiveTotal() uint64 {
	active := uint64(0)
	countEntry := func(_, _ any) bool {
		active++
		return true
	}
	terminalSessions.Range(countEntry)
	return active
}

// closeTerminalSessionByAllocation cancels the session that terminalAllocSessions
// maps allocationID to, if that session is still registered. reason is only
// logged. It returns true when a session was found and its cancel func was
// invoked; it does not wait for the session to finish.
func closeTerminalSessionByAllocation(allocationID, reason string) bool {
	// A missing mapping yields a nil interface, and the comma-ok assertions
	// below turn that into "" / not-a-session without panicking.
	mapped, _ := terminalAllocSessions.Load(allocationID)
	sessionID, _ := mapped.(string)
	if sessionID == "" {
		return false
	}
	entry, found := terminalSessions.Load(sessionID)
	sess, isSession := entry.(*terminalSession)
	if !found || !isSession {
		return false
	}
	slog.Info("closing terminal session", "allocation_id", allocationID, "session_id", sessionID, "reason", strings.TrimSpace(reason))
	sess.cancel()
	return true
}

// streamPTYSession runs one terminal session over the internal websocket
// transport (the only transport implemented in this file) and returns how the
// session ended.
func streamPTYSession(ctx context.Context, in terminalOpenInput, startup chan<- terminalStartupResult) terminalSessionOutcome {
	outcome := streamPTYSessionInternalWS(ctx, in, startup)
	return outcome
}

// streamPTYSessionInternalWS runs one terminal session end to end. It:
//   - dials the gateway's internal terminal websocket with the node's mTLS identity;
//   - builds the session command (terminalCommandForInput) and starts it on a PTY;
//   - relays PTY output to the gateway as binary messages and gateway binary
//     messages to the PTY, handling JSON "resize"/"close" text frames;
//   - stops when the PTY ends, the gateway closes, or ctx is cancelled.
//
// startup receives one best-effort report via reportTerminalStartup: the open
// error, or nil once the PTY is running.
//
// Once the PTY has started, the child process is killed if still running and
// always reaped before return, and a "close" control frame is attempted. The
// returned outcome records the close reason, initiator and any error.
func streamPTYSessionInternalWS(ctx context.Context, in terminalOpenInput, startup chan<- terminalStartupResult) terminalSessionOutcome {
	// Upper bound on any single websocket write. Without it, a gateway that
	// stops reading blocks the PTY pump inside a write while it holds the
	// write lock, and teardown can never send its close frame or return.
	const wsWriteTimeout = 10 * time.Second

	// openFailed reports a startup failure and builds the matching outcome.
	openFailed := func(err error) terminalSessionOutcome {
		reportTerminalStartup(startup, err)
		return terminalSessionOutcome{protocolReason: "open_failed", initiator: "node_agent", err: err}
	}

	apiURL := strings.TrimRight(strings.TrimSpace(envOrDefault("GPUAAS_TERMINAL_API_URL", os.Getenv("GPUAAS_API_URL"))), "/")
	nodeID := strings.TrimSpace(os.Getenv("GPUAAS_NODE_ID"))
	certPath := envOrDefault("GPUAAS_CERT_PATH", ".data/node-agent/cert.pem")
	keyPath := envOrDefault("GPUAAS_KEY_PATH", ".data/node-agent/key.pem")
	caBundlePath := envOrDefault("GPUAAS_CA_BUNDLE_PATH", ".data/node-agent/ca-bundle.crt")
	if apiURL == "" || nodeID == "" {
		return openFailed(fmt.Errorf("terminal websocket env not configured"))
	}

	wsURL, err := terminalInternalWSURL(apiURL, in.SessionID)
	if err != nil {
		return openFailed(err)
	}
	dialer, err := nodeMTLSDialer(certPath, keyPath, caBundlePath)
	if err != nil {
		return openFailed(err)
	}

	nodeInstanceID := strings.TrimSpace(envOrDefault("GPUAAS_NODE_INSTANCE_ID", ""))
	headers := http.Header{}
	headers.Set("User-Agent", fmt.Sprintf("gpuaas-node-agent/%s pid=%d instance=%s", buildinfo.Version, os.Getpid(), nodeInstanceID))
	slog.Info("terminal websocket request starting",
		"session_id", in.SessionID,
		"allocation_id", in.AllocationID,
		"node_id", nodeID,
		"node_instance_id", nodeInstanceID,
		"api_url", wsURL,
	)
	wsConn, resp, err := dialer.DialContext(ctx, wsURL, headers)
	if err != nil {
		closeTerminalHandshakeResponse(resp)
		return openFailed(err)
	}
	defer func() { _ = wsConn.Close() }()
	slog.Info("terminal websocket connected",
		"session_id", in.SessionID,
		"allocation_id", in.AllocationID,
		"node_id", nodeID,
		"username", in.Username,
		"capacity_shape", in.CapacityShape,
	)

	// gorilla/websocket allows only one concurrent writer, so every data and
	// control write goes through writeWS, which also applies the deadline.
	var wsWriteMu sync.Mutex
	writeWS := func(write func() error) error {
		wsWriteMu.Lock()
		defer wsWriteMu.Unlock()
		if err := wsConn.SetWriteDeadline(time.Now().Add(wsWriteTimeout)); err != nil {
			return err
		}
		return write()
	}
	writeControl := func(frame terminalWSControlFrame) error {
		return writeWS(func() error { return writeTerminalWSControl(wsConn, frame) })
	}
	// rejectOpen tells the gateway the session could not be opened, then
	// reports the failure to the caller.
	rejectOpen := func(err error) terminalSessionOutcome {
		_ = writeControl(terminalWSControlFrame{
			Type:            "close",
			CloseReason:     "open_failed",
			ErrorCode:       classifyTerminalOpenError(err),
			ErrorMessage:    "terminal session failed to open on node",
			RuntimeUsername: in.Username,
		})
		return openFailed(err)
	}

	cmd, err := terminalCommandForInput(ctx, in)
	if err != nil {
		return rejectOpen(err)
	}
	ptmx, err := pty.Start(cmd)
	if err != nil {
		return rejectOpen(fmt.Errorf("start terminal shell: %w", err))
	}
	defer func() { _ = ptmx.Close() }()
	slog.Info("terminal pty started",
		"session_id", in.SessionID,
		"allocation_id", in.AllocationID,
		"username", in.Username,
		"command", strings.Join(cmd.Args, " "),
	)
	reportTerminalStartup(startup, nil)

	// safePTYWinsize rejects non-positive and out-of-range sizes.
	if winsize, ok := safePTYWinsize(in.Cols, in.Rows); ok {
		_ = pty.Setsize(ptmx, winsize)
	}

	// Each pump sends exactly one outcome; one slot apiece means neither ever
	// blocks on send, even after this function stops receiving.
	outcomeCh := make(chan terminalSessionOutcome, 2)
	ptyDone := make(chan struct{})
	go func() {
		defer close(ptyDone)
		outcomeCh <- pumpTerminalPTYToWS(in, ptmx, func(data []byte) error {
			return writeWS(func() error { return wsConn.WriteMessage(websocket.BinaryMessage, data) })
		})
	}()
	// This pump exits once wsConn is closed by the deferred Close above.
	go func() {
		outcomeCh <- pumpTerminalWSToPTY(wsConn, ptmx)
	}()

	// Watching ctx here, rather than in a dedicated goroutine, leaves nothing
	// blocked on ctx after the session ends.
	var outcome terminalSessionOutcome
	select {
	case outcome = <-outcomeCh:
	case <-ctx.Done():
		outcome = terminalSessionOutcome{protocolReason: "normal_close", initiator: "node_agent_context"}
	}

	_ = writeControl(terminalWSControlFrame{
		Type:        "close",
		CloseReason: normalizeTerminalProtocolReason(outcome.protocolReason),
	})
	// Closing the PTY unblocks the PTY pump. Killing the child covers sessions
	// ended by the gateway or ctx while the shell is still running; for an
	// already-exited child it is a harmless no-op.
	_ = ptmx.Close()
	_ = cmd.Process.Kill()
	<-ptyDone
	// Reap the child so it does not linger as a zombie. Its exit status is not
	// part of the session outcome.
	_ = cmd.Wait()
	return outcome
}

// pumpTerminalPTYToWS copies PTY output to the gateway through writeBinary
// until the PTY read or a websocket write fails, and returns the resulting
// outcome. writeBinary must be done with its argument when it returns, because
// the read buffer is reused.
func pumpTerminalPTYToWS(in terminalOpenInput, ptmx io.Reader, writeBinary func([]byte) error) terminalSessionOutcome {
	buf := make([]byte, 4096)
	loggedFirstData := false
	for {
		n, readErr := ptmx.Read(buf)
		if n > 0 {
			if !loggedFirstData {
				loggedFirstData = true
				slog.Info("terminal pty produced first output",
					"session_id", in.SessionID,
					"allocation_id", in.AllocationID,
					"username", in.Username,
					"bytes", n,
					"preview", terminalLogPreview(buf[:n]),
				)
			}
			if err := writeBinary(buf[:n]); err != nil {
				return terminalSessionOutcome{protocolReason: "node_stream_dropped", initiator: "node_agent", err: err}
			}
		}
		if readErr == nil {
			continue
		}
		if !loggedFirstData {
			slog.Info("terminal pty ended before output",
				"session_id", in.SessionID,
				"allocation_id", in.AllocationID,
				"username", in.Username,
				"error", readErr,
			)
		}
		if readErr == io.EOF {
			return terminalSessionOutcome{protocolReason: "normal_close", initiator: "pty"}
		}
		// On Linux, reading the PTY master after the child side has closed
		// (the shell exited) fails with EIO, wrapped by os.File in a
		// *os.PathError, rather than io.EOF. Treat it as a normal close.
		if pathErr, ok := readErr.(*os.PathError); ok && pathErr.Err == syscall.EIO {
			return terminalSessionOutcome{protocolReason: "normal_close", initiator: "pty"}
		}
		return terminalSessionOutcome{protocolReason: "node_stream_dropped", initiator: "pty", err: readErr}
	}
}

// pumpTerminalWSToPTY applies gateway messages to the PTY until the websocket
// read fails, the gateway sends a "close" frame, or a PTY write or frame
// decode fails, and returns the resulting outcome. Binary messages are
// keystrokes written to the PTY; text messages are JSON control frames
// ("resize" and "close" are acted on, other types are ignored).
func pumpTerminalWSToPTY(wsConn *websocket.Conn, ptmx *os.File) terminalSessionOutcome {
	for {
		msgType, payload, err := wsConn.ReadMessage()
		if err != nil {
			return terminalSessionOutcome{protocolReason: "normal_close", initiator: "gateway"}
		}
		switch msgType {
		case websocket.BinaryMessage:
			if _, err := ptmx.Write(payload); err != nil {
				return terminalSessionOutcome{protocolReason: "node_stream_dropped", initiator: "pty", err: err}
			}
		case websocket.TextMessage:
			var frame terminalWSControlFrame
			if err := json.Unmarshal(payload, &frame); err != nil {
				return terminalSessionOutcome{protocolReason: "node_stream_dropped", initiator: "gateway", err: err}
			}
			switch strings.TrimSpace(strings.ToLower(frame.Type)) {
			case "resize":
				if frame.Cols != nil && frame.Rows != nil {
					if winsize, ok := safePTYWinsize(*frame.Cols, *frame.Rows); ok {
						_ = pty.Setsize(ptmx, winsize)
					}
				}
			case "close":
				return terminalSessionOutcome{
					protocolReason: normalizeTerminalProtocolReason(frame.CloseReason),
					initiator:      "gateway",
				}
			}
		}
	}
}

// closeTerminalHandshakeResponse closes the body of a failed websocket
// handshake response. It tolerates a nil response or a nil body, and ignores
// the error from Close.
func closeTerminalHandshakeResponse(resp *http.Response) {
	if resp != nil && resp.Body != nil {
		_ = resp.Body.Close()
	}
}

// reportTerminalStartup delivers a startup result on startup without ever
// blocking. A nil channel is ignored. If the channel cannot accept the value
// right now (its buffer is full or no receiver is ready), the report is
// dropped.
func reportTerminalStartup(startup chan<- terminalStartupResult, err error) {
	if startup != nil {
		select {
		case startup <- terminalStartupResult{err: err}:
		default:
			// Channel not ready; drop this report.
		}
	}
}

// terminalOpenStartupTimeout returns how long runTerminalOpenSession waits for
// a session to report startup. GPUAAS_TERMINAL_OPEN_START_TIMEOUT_SECONDS may
// set a whole number of seconds in [1, 120] (surrounding whitespace allowed).
// An unset, unparsable or out-of-range value yields the 20s default.
func terminalOpenStartupTimeout() time.Duration {
	const defaultTimeout = 20 * time.Second
	// An empty value fails Atoi and therefore falls back to the default too.
	seconds, err := strconv.Atoi(strings.TrimSpace(os.Getenv("GPUAAS_TERMINAL_OPEN_START_TIMEOUT_SECONDS")))
	switch {
	case err != nil:
		return defaultTimeout
	case seconds < 1 || seconds > 120:
		return defaultTimeout
	default:
		return time.Duration(seconds) * time.Second
	}
}

// normalizeTerminalProtocolReason maps reason (case- and
// surrounding-whitespace-insensitively) onto the protocol's known close
// reasons. Anything unrecognized, including "", becomes "normal_close".
func normalizeTerminalProtocolReason(reason string) string {
	switch canonical := strings.TrimSpace(strings.ToLower(reason)); canonical {
	case "allocation_released", "session_timeout", "node_stream_dropped", "open_failed":
		return canonical
	}
	return "normal_close"
}

// terminalInternalWSURL derives the node-facing terminal websocket URL for
// sessionID from the API base URL:
//   - http becomes ws and https becomes wss;
//   - the path is replaced by /internal/ws/terminal/<sessionID>;
//   - any query, trailing "?" or fragment is dropped.
//
// The session ID is trimmed and escaped as a single path segment, so IDs
// containing "/" or "?" cannot change the route.
//
// It returns an error when apiURL is empty, unparsable, not http(s), or has
// no host, and when sessionID is empty or a "." / ".." dot segment.
func terminalInternalWSURL(apiURL, sessionID string) (string, error) {
	if strings.TrimSpace(apiURL) == "" {
		return "", fmt.Errorf("terminal api url missing")
	}
	parsed, err := url.Parse(apiURL)
	if err != nil {
		return "", fmt.Errorf("parse terminal api url: %w", err)
	}
	switch parsed.Scheme {
	case "http":
		parsed.Scheme = "ws"
	case "https":
		parsed.Scheme = "wss"
	default:
		return "", fmt.Errorf("terminal api url must use http or https")
	}
	if parsed.Host == "" {
		return "", fmt.Errorf("terminal api url missing host")
	}
	segment := strings.TrimSpace(sessionID)
	if segment == "" || segment == "." || segment == ".." {
		return "", fmt.Errorf("terminal session id invalid")
	}

	const basePath = "/internal/ws/terminal/"
	parsed.Path = basePath + segment
	// RawPath carries the escaped form so URL.String keeps a "/" inside the ID
	// encoded as %2F instead of emitting it as a path separator.
	parsed.RawPath = basePath + url.PathEscape(segment)
	parsed.RawQuery = ""
	parsed.ForceQuery = false
	parsed.Fragment = ""
	return parsed.String(), nil
}

// nodeMTLSDialer builds a websocket dialer that authenticates with the node's
// client certificate (certPath/keyPath) and trusts only the roots loaded from
// caBundlePath via loadRootCAs. TLS 1.2 is the minimum version and the
// websocket handshake is bounded to 15 seconds.
func nodeMTLSDialer(certPath, keyPath, caBundlePath string) (*websocket.Dialer, error) {
	clientCert, certErr := tls.LoadX509KeyPair(certPath, keyPath)
	if certErr != nil {
		return nil, fmt.Errorf("load node cert/key: %w", certErr)
	}
	rootPool, rootsErr := loadRootCAs(caBundlePath)
	if rootsErr != nil {
		return nil, fmt.Errorf("load ca roots: %w", rootsErr)
	}
	tlsConfig := &tls.Config{
		MinVersion:   tls.VersionTLS12,
		Certificates: []tls.Certificate{clientCert},
		RootCAs:      rootPool,
	}
	dialer := &websocket.Dialer{
		TLSClientConfig:  tlsConfig,
		HandshakeTimeout: 15 * time.Second,
	}
	return dialer, nil
}

// writeTerminalWSControl sends frame on conn as a JSON-encoded websocket
// message. It does no locking of its own, so callers must serialize it with
// any other writer on conn.
func writeTerminalWSControl(conn *websocket.Conn, frame terminalWSControlFrame) error {
	err := conn.WriteJSON(frame)
	return err
}

// terminalCommandForInput picks the session command for in. When
// in.CapacityShape is "gpu_slice" (ignoring surrounding whitespace), it uses
// an ssh session to in.TargetHost:in.TargetPort. Otherwise it uses a local
// login shell for in.Username.
func terminalCommandForInput(ctx context.Context, in terminalOpenInput) (*exec.Cmd, error) {
	switch strings.TrimSpace(in.CapacityShape) {
	case "gpu_slice":
		return terminalSSHCommandForSlice(ctx, in.Username, in.TargetHost, in.TargetPort)
	default:
		return terminalCommandForUser(ctx, in.Username)
	}
}

// terminalCommandForUser prepares, without starting, an interactive shell
// that runs as the local account named username.
//
// The agent must be root (terminalEUID() == 0) because the command switches
// to the account's uid, primary gid and parsable supplementary gids through
// SysProcAttr.Credential.
//
// Shell: /bin/bash -i -l, or /bin/sh -i when /bin/bash cannot be stat'ed. It
// runs in the account's home directory when one is set.
//
// Environment: the agent's own, plus HOME, USER and LOGNAME for the account.
// TERM, COLORTERM and SHELL get defaults only when not already present.
//
// Errors: not running as root, a failed user lookup (wrapped as
// "lookup terminal user: ..."), or a uid/gid that is not an integer in
// [0, 2^32-1].
func terminalCommandForUser(ctx context.Context, username string) (*exec.Cmd, error) {
	if terminalEUID() != 0 {
		return nil, fmt.Errorf("terminal runtime requires root execution context")
	}

	account, err := lookupTerminalUser(strings.TrimSpace(username))
	if err != nil {
		return nil, fmt.Errorf("lookup terminal user: %w", err)
	}
	uid, uidErr := strconv.Atoi(strings.TrimSpace(account.Uid))
	if uidErr != nil || uid < 0 {
		return nil, fmt.Errorf("invalid terminal uid")
	}
	gid, gidErr := strconv.Atoi(strings.TrimSpace(account.Gid))
	if gidErr != nil || gid < 0 {
		return nil, fmt.Errorf("invalid terminal gid")
	}
	// syscall.Credential stores 32-bit ids.
	if uid > maxUint32Int {
		return nil, fmt.Errorf("invalid terminal uid")
	}
	if gid > maxUint32Int {
		return nil, fmt.Errorf("invalid terminal gid")
	}

	shellPath, shellArgs := "/bin/bash", []string{"-i", "-l"}
	if _, statErr := os.Stat(shellPath); statErr != nil {
		shellPath, shellArgs = "/bin/sh", []string{"-i"}
	}
	cmd := exec.CommandContext(ctx, shellPath, shellArgs...)
	home := strings.TrimSpace(account.HomeDir)
	if home != "" {
		cmd.Dir = home
	}
	login := strings.TrimSpace(account.Username)
	// NOTE(review): the shell inherits the agent's entire (root) environment;
	// confirm nothing sensitive is exported to user sessions this way.
	// The appended account values override inherited ones, since exec.Cmd
	// uses the last value of a duplicated key.
	cmd.Env = append(os.Environ(), "HOME="+home, "USER="+login, "LOGNAME="+login)
	for _, fallback := range [...]struct{ key, value string }{
		{"TERM", "xterm-256color"},
		{"COLORTERM", "truecolor"},
		{"SHELL", shellPath},
	} {
		if !containsEnvKey(cmd.Env, fallback.key) {
			cmd.Env = append(cmd.Env, fallback.key+"="+fallback.value)
		}
	}

	// Supplementary groups are best-effort: the lookup error is ignored and
	// ids that do not parse as uint32 are skipped.
	var supplementary []uint32
	groupIDs, _ := account.GroupIds()
	for _, rawGID := range groupIDs {
		value, parseErr := strconv.ParseUint(strings.TrimSpace(rawGID), 10, 32)
		if parseErr == nil {
			supplementary = append(supplementary, uint32(value))
		}
	}
	cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{
		Uid:    uint32(uid),
		Gid:    uint32(gid),
		Groups: supplementary,
	}}
	return cmd, nil
}

// terminalSSHCommandForSlice prepares, without starting, an interactive ssh
// session ("ssh -tt") to username@host:port for a gpu_slice VM.
//
// It validates the inputs: username must be non-empty and match
// usernameOnNodePattern, host must be non-empty, and port must be 1-65535.
//
// The private key comes from GPUAAS_SLICE_TERMINAL_SSH_KEY_PATH, falling back
// to defaultSliceVMTerminalSSHKeyPath, and must exist. ssh runs in batch mode
// with keepalives; TERM and COLORTERM are appended to the agent's environment.
//
// NOTE(review): host key verification is disabled (StrictHostKeyChecking=no,
// known_hosts pointed at /dev/null). Confirm the network path to slice VMs is
// trusted, since this permits MITM.
func terminalSSHCommandForSlice(ctx context.Context, username, host string, port int) (*exec.Cmd, error) {
	login := strings.TrimSpace(username)
	targetHost := strings.TrimSpace(host)
	switch {
	case login == "" || !usernameOnNodePattern.MatchString(login):
		return nil, fmt.Errorf("invalid terminal ssh username")
	case targetHost == "":
		return nil, fmt.Errorf("terminal ssh target host missing")
	case port < 1 || port > 65535:
		return nil, fmt.Errorf("terminal ssh target port out of bounds")
	}

	keyPath := strings.TrimSpace(os.Getenv("GPUAAS_SLICE_TERMINAL_SSH_KEY_PATH"))
	if keyPath == "" {
		keyPath = defaultSliceVMTerminalSSHKeyPath
	}
	if _, statErr := os.Stat(keyPath); statErr != nil {
		return nil, fmt.Errorf("terminal ssh key unavailable: %w", statErr)
	}

	sshArgs := []string{"-tt", "-i", keyPath, "-p", strconv.Itoa(port)}
	for _, option := range []string{
		"BatchMode=yes",
		"StrictHostKeyChecking=no",
		"UserKnownHostsFile=/dev/null",
		"GlobalKnownHostsFile=/dev/null",
		"LogLevel=ERROR",
		"ServerAliveInterval=30",
		"ServerAliveCountMax=3",
	} {
		sshArgs = append(sshArgs, "-o", option)
	}
	sshArgs = append(sshArgs, login+"@"+targetHost)

	cmd := exec.CommandContext(ctx, "ssh", sshArgs...)
	cmd.Env = append(os.Environ(), "TERM=xterm-256color", "COLORTERM=truecolor")
	return cmd, nil
}

// classifyTerminalOpenError maps a session-open error to a stable error code
// for the gateway by matching known substrings of its message
// (case-insensitively). Unmatched errors map to "terminal_open_failed".
// err must be non-nil.
func classifyTerminalOpenError(err error) string {
	message := strings.ToLower(strings.TrimSpace(err.Error()))
	// Order matters: wrapped messages can contain several markers (e.g.
	// "lookup terminal user: ... unknown user ..."), and the first match wins.
	rules := [...]struct{ marker, code string }{
		{"unknown user", "terminal_user_not_found"},
		{"lookup terminal user", "terminal_user_lookup_failed"},
		{"requires root execution context", "terminal_runtime_not_root"},
		{"start terminal shell", "terminal_shell_start_failed"},
		{"terminal ssh key unavailable", "terminal_ssh_key_unavailable"},
		{"terminal ssh target", "terminal_ssh_target_invalid"},
	}
	for _, rule := range rules {
		if strings.Contains(message, rule.marker) {
			return rule.code
		}
	}
	return "terminal_open_failed"
}

// safePTYWinsize converts a requested terminal size into a pty.Winsize. It
// returns false when either dimension is non-positive or does not fit in
// uint16, so callers never narrow an out-of-range int.
func safePTYWinsize(cols, rows int) (*pty.Winsize, bool) {
	fits := func(v int) bool { return v > 0 && v <= maxUint16Int }
	if !fits(cols) || !fits(rows) {
		return nil, false
	}
	size := &pty.Winsize{Rows: uint16(rows), Cols: uint16(cols)}
	return size, true
}

// containsEnvKey reports whether env (KEY=value entries) already defines key,
// after trimming surrounding whitespace from key. A blank key never matches.
// Matching is an exact, case-sensitive "KEY=" prefix check, so an entry with
// an empty value still counts as defined.
func containsEnvKey(env []string, key string) bool {
	name := strings.TrimSpace(key)
	if len(name) == 0 {
		return false
	}
	want := name + "="
	for i := range env {
		if strings.HasPrefix(env[i], want) {
			return true
		}
	}
	return false
}

// terminalLogPreview renders at most the first 64 bytes of data as a
// log-safe string: newline, carriage return and tab become the two-character
// escapes \n, \r and \t, printable ASCII is kept, and every other byte
// (including UTF-8 continuation bytes) becomes '.'.
func terminalLogPreview(data []byte) string {
	const maxPreview = 64
	if len(data) == 0 {
		return ""
	}
	if len(data) > maxPreview {
		data = data[:maxPreview]
	}
	var out strings.Builder
	out.Grow(len(data))
	for i := 0; i < len(data); i++ {
		c := data[i]
		switch c {
		case '\n':
			out.WriteString(`\n`)
		case '\r':
			out.WriteString(`\r`)
		case '\t':
			out.WriteString(`\t`)
		default:
			if c < ' ' || c > '~' {
				c = '.'
			}
			out.WriteByte(c)
		}
	}
	return out.String()
}
