diff --git a/pkgs/uvms-guest/guest.py b/pkgs/uvms-guest/guest.py index 87b3d1a..575245e 100644 --- a/pkgs/uvms-guest/guest.py +++ b/pkgs/uvms-guest/guest.py @@ -3,74 +3,117 @@ import os import select import socket import subprocess +import sys -def handle_run(run: dict) -> dict: - res = {} - text = run.get("text", False) - env = { - **os.environ, - "PATH": ":".join( - os.environ.get("PATH", "").split(":") + run.get("EXTRA_PATH", []) - ), - } - proc = None - try: - proc = subprocess.Popen( - req["run"]["argv"], - text=text, - env=env, - cwd="/home/user", - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - ) - res["status"] = "exec succeeded" - except Exception as e: - res["status"] = "exec failed" - res["exception"] = repr(e) - res["pid"] = getattr(proc, "pid", None) - try: - if proc is not None: - proc.wait(0.125) - res["long_running"] = False - res["returncode"] = getattr(proc, "returncode", None) - except subprocess.TimeoutExpired: - res["long_running"] = True - return res, proc +class Processes: + def __init__(self): + self.processes = [] + self.sources = [] + self.liveness_fds = dict() + self.client_fds = set() + + def popen(self, *args, **kwargs): + a, b = socket.socketpair() + pass_fds = [*kwargs.get("pass_fds", ()), b.fileno()] + proc = subprocess.Popen(*args, **kwargs, pass_fds=pass_fds) + self.processes.append(proc) + self.sources.append(a) + assert a.fileno() not in self.liveness_fds + self.liveness_fds[a.fileno()] = proc + b.close() + return proc + + def handle_run(self, run: dict) -> dict: + res = {} + text = run.get("text", False) + env = { + **os.environ, + "PATH": ":".join( + [ + *os.environ.get("PATH", "").split(":"), + *run.get( + "EXTRA_PATH", + [], + ), + ], + ), + } + proc = None + try: + proc = self.popen( + req["run"]["argv"], + text=text, + env=env, + cwd="/home/user", + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + res["status"] = "exec succeeded" + except Exception as e: + print(e) + res["status"] = "exec failed" + res["exception"] = repr(e) + res["pid"] = getattr(proc, "pid", None) + try: + if proc is not None: + proc.wait(0.125) + res["long_running"] = False + res["returncode"] = getattr(proc, "returncode", None) + except subprocess.TimeoutExpired: + res["long_running"] = True + return res, proc + + def accept_vsock(self, s): + con, (cid, port) = serv.accept() + assert cid == 2, cid + self.sources.append(con) + self.client_fds.insert(con.fileno()) + return con, (cid, port) if __name__ == "__main__": + ps = Processes() serv = socket.fromfd(3, socket.AF_VSOCK, socket.SOCK_STREAM) - - procs = [] - conns = [serv] + ps.sources.append(serv) while True: - rr, rw, xs = select.select(conns, [], []) + rr, rw, xs = select.select(ps.sources, [], ps.sources) + for con in (*rr, *xs): + if con.fileno() in ps.liveness_fds: + assert con.recv(128) == b"" + proc = ps.liveness_fds[con.fileno()] + proc.wait() + assert proc.returncode is not None, proc + print(f"{proc} has terminated, shutting down") + sys.exit(proc.returncode) for con in rr: if con is serv: - con, (cid, port) = serv.accept() - assert cid == 2, cid - conns.append(con) - continue - req = con.recv(8192) - # IDK why but I keep getting empty messages - if req == b"": - continue - try: - req = json.loads(req) - print(f"Received {req=}") - except json.JSONDecodeError as e: - print(f"Couldn't interpret {req=}: {e}") - continue - if "run" in req: - res, proc = handle_run(req["run"]) - procs.append(proc) + con, _ = ps.accept_vsock(serv) + print(f"Open [{con.fileno()}]") + if con.fileno() in ps.liveness_fds: + assert False, "Must already be processed" + elif con.fileno() in ps.client_fds: + req = con.recv(8192) + # IDK why but I keep getting empty messages + if req == b"": + print(f"Lost [{con.fileno()}]") + continue + try: + req = json.loads(req) + print(f"Received {req=}") + except json.JSONDecodeError as e: + print(f"Couldn't interpret {req=}: {e}") + continue + if "run" in req: + res, proc = ps.handle_run(req["run"]) + else: + res = {"status": "unknown command"} + _, rw, _ = select.select([], [con], []) + assert rw, rw + res = json.dumps(res).encode("utf8") + print(f"Responding with {res=}") + con.send(res) else: - res = {"status": "unknown command"} - _, rw, _ = select.select([], [con], []) - assert rw, rw - res = json.dumps(res).encode("utf8") - print(f"Responding with {res=}") - con.send(res) + assert False, con.fileno() diff --git a/pkgs/uvms/uvms.py b/pkgs/uvms/uvms.py index ab71c6e..c473f11 100644 --- a/pkgs/uvms/uvms.py +++ b/pkgs/uvms/uvms.py @@ -452,18 +452,21 @@ def removing(*paths): os.remove(p) +@contextmanager def connect_ch_vsock( vsock_sock_path, port: int, type=socket.SOCK_STREAM, blocking=True, ) -> socket.socket: + os.makedirs(os.path.dirname(vsock_sock_path), exist_ok=True) s = socket.socket(socket.AF_UNIX, type, 0) s.setblocking(blocking) s.connect(vsock_sock_path) - s.send(b"CONNECT %d\n" % port) - return s + with removing(vsock_sock_path): + s.send(b"CONNECT %d\n" % port) + yield s @contextmanager @@ -473,15 +476,14 @@ def listen_ch_vsock( type=socket.SOCK_STREAM, blocking=True, ) -> socket.socket: + os.makedirs(os.path.dirname(vsock_sock_path), exist_ok=True) listen_path = vsock_sock_path + "_%d" % port s = socket.socket(socket.AF_UNIX, type, 0) s.setblocking(blocking) s.bind(listen_path) s.listen() - try: + with removing(listen_path): yield s - finally: - os.remove(listen_path) def main(args, args_next, cleanup, ps): @@ -589,7 +591,7 @@ def main(args, args_next, cleanup, ps): ps.exec(*ch_remote, "info") with ready_sock: - ready_sock.settimeout(16.0) + ready_sock.settimeout(20.0) try: con, _ = ready_sock.accept() except: # noqa: E722 diff --git a/profiles/baseImage.nix b/profiles/baseImage.nix index 87f8df5..8ba9767 100644 --- a/profiles/baseImage.nix +++ b/profiles/baseImage.nix @@ -39,9 +39,9 @@ in ./on-failure.nix ]; config = { - some.failure-handler.enable = true; + # some.failure-handler.enable = true; hardware.graphics.enable = true; - # boot.kernelPackages = pkgs.linuxPackagesFor uvmsPkgs.linux-uvm; + boot.kernelPackages = pkgs.linuxPackagesFor uvmsPkgs.linux-uvm; # boot.isContainer = true; boot.initrd.kernelModules = [ "drm" @@ -256,14 +256,26 @@ in partOf = [ "uvms-guest.service" ]; }; systemd.services."uvms-guest" = { + requiredBy = [ "multi-user.target" ]; + onFailure = [ "shutdown.service" ]; serviceConfig = { User = "user"; Group = "users"; ExecStart = "${lib.getExe uvmsPkgs.uvms-guest}"; + ExecStop = [ + "/run/current-system/sw/bin/echo GUEST DOWN" + "/run/current-system/sw/bin/systemctl poweroff" + ]; + StandardOutput = "journal+console"; + StandardError = "journal+console"; + Restart = "no"; + }; + }; + systemd.services."shutdown" = { + serviceConfig = { + ExecStart = [ "/run/current-system/sw/bin/systemctl poweroff" ]; StandardOutput = "journal+console"; StandardError = "journal+console"; - Restart = "on-failure"; - RestartSec = 5; }; }; @@ -371,7 +383,7 @@ in options = { size = mkOption { type = types.int; - default = 1536 * 1048576; + default = 3 * 1024 * 1048576; }; shared = mkOption { type = types.bool;