//
// Syd: rock-solid application kernel
// src/kernel/signal.rs: Signal syscall handlers
//
// Copyright (c) 2023, 2024, 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

use libseccomp::ScmpNotifResp;
use nix::{
    errno::Errno,
    unistd::{getpgid, getpgrp, Pid},
};

use crate::hook::UNotifyEventRequest;

/// Seccomp-notify entry point for `kill(2)`.
///
/// The id argument may name a process, a process group (negative id)
/// or the caller's own group (zero), so the shared handler runs with
/// both the thread and the group mode switched off.
pub(crate) fn sys_kill(request: UNotifyEventRequest) -> ScmpNotifResp {
    let (thread, group) = (false, false);
    syscall_signal_handler(request, thread, group)
}

/// Seccomp-notify entry point for `tgkill(2)`.
///
/// Both leading arguments are ids (thread group id, then thread id),
/// so the shared handler runs in thread mode with group mode enabled
/// and validates each of them.
pub(crate) fn sys_tgkill(request: UNotifyEventRequest) -> ScmpNotifResp {
    let (thread, group) = (true, true);
    syscall_signal_handler(request, thread, group)
}

/// Seccomp-notify entry point for `tkill(2)`.
///
/// Only the first argument is an id and it must be positive, so the
/// shared handler runs in thread mode without group mode.
pub(crate) fn sys_tkill(request: UNotifyEventRequest) -> ScmpNotifResp {
    let (thread, group) = (true, false);
    syscall_signal_handler(request, thread, group)
}

/// Seccomp-notify entry point for `pidfd_open(2)`.
///
/// The call is vetted with the same settings as `tkill(2)`: only the
/// first argument is treated as an id, and it must be positive.
pub(crate) fn sys_pidfd_open(request: UNotifyEventRequest) -> ScmpNotifResp {
    let (thread, group) = (true, false);
    syscall_signal_handler(request, thread, group)
}

/// Upper bound (wrap-around limit) for process ids supported by the kernel.
///
/// 32-bit platforms cap this at 32768, whereas 64-bit platforms allow
/// up to 2^22, i.e. 4 * 1024 * 1024 (roughly four million).
///
/// Syd may be built as a 32-bit binary running on a 64-bit kernel, so
/// the larger 64-bit limit is used unconditionally.
const PID_MAX_LIMIT: libc::pid_t = 4 * 1024 * 1024;

/// Handles syscalls related to signal handling, protecting the syd
/// process and their threads from signals.
///
/// # Parameters
///
/// - `request`: User notification request from seccomp.
/// - `thread`: true if the system call is directed to a thread rather
///   than a process.
/// - `group`: true if the system call has both progress group id and
///   process id (tgkill), false otherwise.
///
/// - `ScmpNotifResp`: Response indicating the result of the syscall handling.
#[allow(clippy::cognitive_complexity)]
fn syscall_signal_handler(
    request: UNotifyEventRequest,
    thread: bool,
    group: bool,
) -> ScmpNotifResp {
    let req = request.scmpreq;

    // Validate pid/tid.
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_possible_wrap)]
    let pid = req.data.args[0] as libc::pid_t;
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_possible_wrap)]
    let tid = req.data.args[1] as libc::pid_t;

    // See:
    // https://github.com/torvalds/linux/blob/f66bc387efbee59978e076ce9bf123ac353b389c/kernel/signal.c#L1579-L1581
    // wrt. i32::MIN check.
    if pid == i32::MIN {
        return request.fail_syscall(Errno::ESRCH);
    }

    if group && tid == i32::MIN {
        return request.fail_syscall(Errno::ESRCH);
    }

    if !(-PID_MAX_LIMIT..=PID_MAX_LIMIT).contains(&pid) {
        return request.fail_syscall(Errno::ESRCH);
    }

    if group && !(-PID_MAX_LIMIT..=PID_MAX_LIMIT).contains(&tid) {
        return request.fail_syscall(Errno::ESRCH);
    }

    if thread && (pid <= 0 || (group && tid <= 0)) {
        return request.fail_syscall(Errno::EINVAL);
    }

    // Guard syd tasks.
    //
    // SAFETY: Return success when denying for stealth.
    // Otherwise the allowed 0 signal can be misused
    // to identify a Syd process.
    //
    // pid <=0 only for kill/sigqueue here.
    match pid {
        0 => {
            // SAFETY: Guard against group signals.
            // kill(0, 9) -> Send signal to _current_ process group.
            match getpgid(Some(req.pid())) {
                Ok(pgrp) if pgrp == getpgrp() => {
                    // SAFETY: This is a version of killpg().
                    // We must stop this signal if Syd is in
                    // the same process group as the process, otherwise
                    // continue is safe. EACCES is no further
                    // information leak as sig==0 is pass-through.
                    return request.fail_syscall(Errno::EACCES);
                }
                Err(_) => {
                    // ESRCH is no further information leak as
                    // sig==0 is pass-through.
                    return request.fail_syscall(Errno::ESRCH);
                }
                _ => {}
            }
        }
        -1 => {
            // SAFETY: We do not allow mass signaling with -1.
            return request.fail_syscall(Errno::EACCES);
        }
        _ => {}
    }

    // kill and sigqueue support negative PIDs.
    let pid_abs = if thread { pid } else { pid.abs() };

    // Check for Syd tasks.
    let syd = Pid::this().as_raw();

    // SAFETY: Note, we deny with EACCES, rather than returning success
    // because we have a kernel-level bpf filter that _allows_ the
    // respective signaler system call _only when_ the signal is 0.
    // Therefore we're not leaking any further information by returning
    // EACCES here.
    if !thread && syd == pid_abs {
        return request.fail_syscall(Errno::EACCES);
    }

    if thread && syd == pid {
        return request.fail_syscall(Errno::EACCES);
    }

    if thread && group && syd == tid {
        return request.fail_syscall(Errno::EACCES);
    }

    // SAFETY: Check for Syd threads with the abstract PID.
    if !thread && Errno::result(unsafe { libc::syscall(libc::SYS_tgkill, syd, pid_abs, 0) }).is_ok()
    {
        return request.fail_syscall(Errno::EACCES);
    }

    // SAFETY: Check for Syd threads with the PID.
    if thread && Errno::result(unsafe { libc::syscall(libc::SYS_tgkill, syd, pid, 0) }).is_ok() {
        return request.fail_syscall(Errno::EACCES);
    }

    if thread
        && group
        && pid != tid
        && Errno::result(
            // SAFETY: Check for Syd threads with the TID.
            unsafe { libc::syscall(libc::SYS_tgkill, syd, tid, 0) },
        )
        .is_ok()
    {
        return request.fail_syscall(Errno::EACCES);
    }

    // Check signals directed to Syd's process group.
    let syd_pgid = getpgrp().as_raw();
    if !thread && syd_pgid == pid_abs {
        return request.fail_syscall(Errno::EACCES);
    }
    if thread && syd_pgid == pid {
        return request.fail_syscall(Errno::EACCES);
    }
    if thread && group && syd_pgid == tid {
        return request.fail_syscall(Errno::EACCES);
    }

    // SAFETY: This is safe because we haven't dereferenced
    // any pointers during access check.
    unsafe { request.continue_syscall() }
}
