#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <grp.h>
#include <pwd.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>

/*
 * Print an optional error message followed by the usage summary to
 * stderr, then terminate the process with exit status 1.  Never returns.
 *
 * msg: text printed as "Error: <msg>" (plus a blank line) before the
 *      usage line, or NULL to print only the usage line.
 */
void usage(const char * const msg)
{
  if(msg)
    fprintf(stderr, "Error: %s\n\n", msg);
  /* Lists every option main() accepts: -l or a bare "-" (login shell),
     -s (shell override), the user name and -c (command string). */
  fprintf(stderr, "Usage: init_su [-l | -] [-s shell] user -c command\n");
  exit(1);
}

/*
 * Run COMMAND as "shell -c command" with USER's uid, gid and
 * supplementary groups, with stdin and all descriptors above stderr
 * closed and the process moved into a new session.
 *
 * Arguments (getopt with a leading '-' in the option string, so
 * non-option arguments are returned as code 1):
 *   -l or a bare "-"   start the shell as a login shell (argv[0] = "-name")
 *   -s shell           use SHELL instead of the user's passwd shell
 *   -c command         command string passed to "shell -c"
 *   user               account to run as (the last one given wins)
 *
 * Argument and privilege errors are reported on stderr through usage(),
 * which exits with status 1.  After the descriptors have been reset,
 * failures are logged to syslog (ident "initrc_su", facility LOG_DAEMON)
 * and main returns 1.  On success execv() does not return.
 */
int main(int argc, char **argv)
{
  int i, fd, opt;
  int login = 0;
  char *command = NULL, *user = NULL, *shell = NULL, *nu_argv[4];
  struct passwd *pw;

  while((opt = getopt(argc, argv, "-lc:s:")) != -1)
  {
    switch(opt)
    {
      case 1:
        /* Non-option argument: "-" requests a login shell, anything
           else is taken as the target user name. */
        if(!strcmp(optarg, "-"))
          login = 1;
        else
          user = optarg;
      break;
      case 'l':
        login = 1;
      break;
      case 's':
        shell = optarg;
      break;
      case 'c':
        command = optarg;
      break;
      default:
        /* Unknown option or missing option argument: getopt() has
           already printed a diagnostic.  Unknown options are ignored; a
           missing -c argument is caught by the check below. */
      break;
    }
  }
  if(!user || !command)
    usage(NULL);
  pw = getpwnam(user);
  if(!pw)
    usage("User unknown.");

  /* setregid()/setreuid() do not touch the supplementary group list, so
     without initgroups() the command would keep the caller's (root's)
     extra groups.  initgroups() and setregid() both need privileges, so
     they must run before the uid is dropped. */
  if(initgroups(pw->pw_name, pw->pw_gid))
    usage("Can't initgroups(), are you root?");
  if(setregid(pw->pw_gid, pw->pw_gid))
    usage("Can't setgid(), are you root?");
  if(setreuid(pw->pw_uid, pw->pw_uid))
    usage("Can't setuid(), are you root?");

  if(!shell)
  {
    shell = pw->pw_shell;
    /* passwd(5): an empty shell field means /bin/sh. */
    if(!shell || !shell[0])
      shell = "/bin/sh";
  }

  if(login)
  {
    /* Login shells get argv[0] = "-" followed by the shell's basename,
       e.g. "/bin/bash" -> "-bash". */
    char *base = strrchr(shell, '/');
    if(!base)
      usage("Bad shell.");
    nu_argv[0] = strdup(base);
    if(!nu_argv[0])
      usage("Out of memory.");
    nu_argv[0][0] = '-';  /* overwrite the copied '/' */
  }
  else
    nu_argv[0] = shell;
  nu_argv[1] = "-c";
  nu_argv[2] = command;
  nu_argv[3] = NULL;

  /* Drop inherited descriptors: stdin and everything from 3 up to 1023.
     Descriptors at or above 1024 are not touched. */
  close(0);
  for(i = 3; i < 1024; i++)
    close(i);
  openlog("initrc_su", LOG_CONS | LOG_NOWAIT, LOG_DAEMON);
  /* open() returns the lowest free descriptor; fd 0 was just closed, so
     this normally becomes the new stdin. */
  fd = open("/dev/null", O_RDWR);
  if(fd == -1)
  {
    syslog(LOG_ERR, "Can't open /dev/null when trying to execute program %s: %m", command);
    return 1;
  }
  /* Keep descriptors 0-2 only if they refer to a regular file or a pipe.
     Anything else (terminal, socket, closed descriptor) is redirected to
     /dev/null. */
  for(i = 0; i < 3; i++)
  {
    struct stat sbuf;
    if(i != fd && (fstat(i, &sbuf) == -1 || (!S_ISREG(sbuf.st_mode) && !S_ISFIFO(sbuf.st_mode)) ))
    {
      close(i);
      if(dup2(fd, i) != i)
      {
        syslog(LOG_ERR, "Can't dup2() when trying to execute program %s: %m", command);
        return 1;
      }
    }
  }
  if(fd >= 3)
    close(fd);
  setsid();  /* it's OK if this fails as we get the right result anyway */
  execv(shell, nu_argv);
  /* Only reached when execv() failed; %m reports the errno it left. */
  syslog(LOG_ERR, "Can't exec program %s: %m", command);
  return 1;
}
