1 package linux
2
3 import (
4 "context"
5 "fmt"
6 "io"
7 "os"
8 "os/exec"
9 "os/user"
10 "strconv"
11 "syscall"
12 "time"
13
14 "golang.conradwood.net/go-easyops/errors"
15 "golang.conradwood.net/go-easyops/utils"
16 "golang.org/x/sys/unix"
17 )
18
19
35
36 type Command interface {
37 SigInt() error
38 SigKill() error
39 SetStdinWriter(r io.Writer)
40 SetStdoutReader(r io.Reader)
41 SetStderrReader(r io.Reader)
42 IsRunning() bool
43 SetDebug(bool)
44 Start(ctx context.Context, com ...string) (ComInstance, error)
45 }
46 type ComInstance interface {
47 Wait(ctx context.Context) error
48 WaitAll(ctx context.Context) error
49 Signal(signal syscall.Signal) error
50 GetCommand() Command
51 }
52 type command struct {
53 stdinwriter io.Writer
54 stdoutreader io.Reader
55 stderrreader io.Reader
56 instance *cominstance
57 debug bool
58 }
59 type cominstance struct {
60 exe []string
61 command *command
62 cgroupdir_cmd string
63 com *exec.Cmd
64 stdout_pipe io.ReadCloser
65 stderr_pipe io.ReadCloser
66 defStdoutReader *comDefaultReader
67 defStderrReader *comDefaultReader
68 }
69
70 func NewCommand() Command {
71 return &command{}
72 }
73 func (c *command) SetDebug(b bool) {
74 c.debug = b
75 }
76 func (c *cominstance) GetCommand() Command {
77 return c.command
78 }
79
80 func (c *command) SetStdinWriter(r io.Writer) {
81 c.stdinwriter = r
82 }
83 func (c *command) SetStdoutReader(r io.Reader) {
84 c.stdoutreader = r
85 }
86 func (c *command) SetStderrReader(r io.Reader) {
87 c.stderrreader = r
88 }
89 func (c *command) IsRunning() bool {
90 return true
91 }
92
93
94 func (c *command) Start(ctx context.Context, com ...string) (ComInstance, error) {
95
96
97 cgroupdir_cmd, err := CreateStandardAdjacentCgroup()
98 if err != nil {
99 return nil, err
100 }
101 c.debugf("Created cgroup \"%s\"\n", cgroupdir_cmd)
102 err = mkdir(cgroupdir_cmd + "/tasks")
103 if err != nil {
104 return nil, err
105 }
106 ci := &cominstance{command: c, cgroupdir_cmd: cgroupdir_cmd}
107 c.instance = ci
108 return ci, ci.start(ctx, com...)
109 }
110 func (ci *cominstance) start(ctx context.Context, com ...string) error {
111 u, err := user.Current()
112 if err != nil {
113 return errors.Wrap(err)
114 }
115
116 uid, err := strconv.Atoi(u.Uid)
117 if err != nil {
118 return errors.Wrap(err)
119 }
120
121 gid, err := strconv.Atoi(u.Gid)
122 if err != nil {
123 return errors.Wrap(err)
124 }
125
126
127 cgroup_fd_path := ci.cgroupdir_cmd
128 cgroup_fd, err := syscall.Open(cgroup_fd_path, unix.O_PATH, 0)
129 if err != nil {
130 return errors.Wrap(err)
131 }
132 fmt.Printf("CgroupFD for \"%s\": %d\n", cgroup_fd_path, cgroup_fd)
133 fmt.Printf("Uid=%d, Gid=%d\n", uid, gid)
134 ci.com = exec.CommandContext(ctx, com[0], com[1:]...)
135 ci.stdout_pipe, err = ci.com.StdoutPipe()
136 if err != nil {
137 return err
138 }
139 ci.defStdoutReader = newDefaultReader(ci.stdout_pipe)
140 ci.stderr_pipe, err = ci.com.StderrPipe()
141 if err != nil {
142 return err
143 }
144 ci.defStderrReader = newDefaultReader(ci.stdout_pipe)
145
146 ci.com.SysProcAttr = &syscall.SysProcAttr{
147 Credential: &syscall.Credential{
148 Uid: uint32(uid),
149 Gid: uint32(gid),
150 NoSetGroups: true,
151 },
152 UseCgroupFD: true,
153 CgroupFD: cgroup_fd,
154 }
155
156
157
158
159 ci.com.SysProcAttr.Credential = nil
160
161 err = ci.com.Start()
162 if err != nil {
163 return errors.Wrap(err)
164 }
165 return nil
166 }
167
168 func (ci *cominstance) Wait(ctx context.Context) error {
169 if ci.com == nil {
170 return nil
171 }
172 err := ci.com.Wait()
173 pids, err := get_pids_for_cgroup(ci.cgroupdir_cmd)
174 if err == nil && len(pids) == 0 {
175 remove_cgroup(ci.cgroupdir_cmd)
176 }
177 return err
178 }
179 func (ci *cominstance) WaitAll(ctx context.Context) error {
180 com_err := ci.Wait(ctx)
181 sig := syscall.SIGINT
182 wait_started := time.Now()
183 for {
184 if time.Since(wait_started) > time.Duration(5)*time.Second {
185 sig = syscall.SIGKILL
186 }
187 pids, err := get_pids_for_cgroup(ci.cgroupdir_cmd)
188 if err != nil {
189 fmt.Printf("Could not get pids for cgroup \"%s\": %s\n", ci.cgroupdir_cmd, err)
190 return err
191 }
192 if len(pids) == 0 {
193 break
194 }
195 for _, pid := range pids {
196 ci.debugf("Sending signal %v to pid %d\n", sig, pid)
197 err = syscall.Kill(int(pid), sig)
198 }
199
200 ci.debugf("Waiting for pid(s): %v\n", pids)
201 waited := false
202 proc, err := os.FindProcess(int(pids[0]))
203 if err != nil {
204 fmt.Printf("Failed to find proc: %s\n", err)
205 } else {
206 _, err := proc.Wait()
207 if err != nil {
208 ci.debugf("failed to wait for proc: %s\n", err)
209 } else {
210 waited = true
211 }
212 }
213 if !waited {
214 time.Sleep(time.Duration(1) * time.Second)
215 }
216 }
217 if com_err != nil {
218 return com_err
219 }
220 fmt.Printf("[go-easyops] All processes exited, now removing cgroup dir (%s)\n", ci.cgroupdir_cmd)
221 remove_cgroup(ci.cgroupdir_cmd)
222 return nil
223 }
224
225 func (c *command) ExitCode() int {
226 return 0
227 }
228 func (c *command) CombinedOutput() []byte {
229 return nil
230 }
231 func (c *command) SigInt() error {
232 fmt.Printf("[go-easyops] sending sigint\n")
233 ci := c.instance
234 if ci == nil {
235 return errors.Errorf("no instance to send signal to")
236 }
237 return ci.Signal(syscall.SIGINT)
238
239 }
240 func (c *command) SigKill() error {
241 fmt.Printf("[go-easyops] sending sigkill\n")
242 ci := c.instance
243 if ci == nil {
244 return errors.Errorf("no instance to send signal to")
245 }
246 return ci.Signal(syscall.SIGKILL)
247 }
248
249 func (ci *cominstance) Signal(sig syscall.Signal) error {
250 pids, err := get_pids_for_cgroup(ci.cgroupdir_cmd)
251 if err != nil {
252 fmt.Printf("Could not get pids for cgroup \"%s\": %s\n", ci.cgroupdir_cmd, err)
253 return err
254 }
255 fmt.Printf("[go-easyops] Cgroupdir \"%s\" has %d pids\n", ci.cgroupdir_cmd, len(pids))
256 for _, pid := range pids {
257 fmt.Printf("[go-easyops] Sending signal %v to pid %d\n", sig, pid)
258 err = syscall.Kill(int(pid), sig)
259 if err != nil {
260 return errors.Wrap(err)
261 }
262 }
263 return nil
264 }
265
266 func mkdir(dir string) error {
267 if utils.FileExists(dir) {
268 return nil
269 }
270 err := os.MkdirAll(dir, 0777)
271 if err != nil {
272 return errors.Wrap(err)
273 }
274 if !utils.FileExists(dir) {
275 return errors.Errorf("failed to create \"%s\"", dir)
276 }
277 return nil
278 }
279
280 func (ci *cominstance) debugf(format string, args ...any) {
281 if !ci.command.debug {
282 return
283 }
284 x := fmt.Sprintf(format, args...)
285 prefix := fmt.Sprintf("[%s] ", ci.exe[0])
286 fmt.Printf("%s%s", prefix, x)
287 }
288 func (c *command) debugf(format string, args ...any) {
289 if !c.debug {
290 return
291 }
292 x := fmt.Sprintf(format, args...)
293 prefix := "[no instance] "
294 if c.instance != nil && c.instance.exe != nil {
295 prefix = fmt.Sprintf("[%s] ", c.instance.exe[0])
296 }
297 fmt.Printf("%s%s", prefix, x)
298 }
299
View as plain text