/* Logo Search packages:
   Sourcecode: ecl version File versions  Download package

   pexpr.c  */

/* Program for computing integer expressions using the GNU Multiple Precision
   Arithmetic Library.

Copyright 1997, 1999, 2000, 2001, 2002 Free Software Foundation, Inc.

This program is free software; you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation; either version 2 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program; if not, write to the Free Software Foundation, Inc., 59 Temple
Place - Suite 330, Boston, MA 02111-1307, USA.  */


/* This expression evaluator works by building an expression tree (using a
   recursive descent parser) which is then evaluated.  The expression tree is
   useful since we want to optimize certain expressions (like a^b % c).

   Usage: pexpr [options] expr ...
   (Assuming you called the executable `pexpr' of course.)

   Command line options:

   -b        print output in binary
   -o        print output in octal
   -d        print output in decimal (the default)
   -x        print output in hexadecimal
   -X        print output in hexadecimal, using upper-case digits
   -b<NUM>   print output in base NUM (2 to 36)
   -t        print timing information
   -html     output html
   -wml      output wml
   -split    split long output lines after every 60th digit
   -noprint  do not print the result, only its approximate number of digits
*/

/* Define LIMIT_RESOURCE_USAGE if you want to make sure the program doesn't
   use up extensive resources (cpu, memory).  Useful for the GMP demo on the
   GMP web site, since we cannot load the server too much.  */

#include "pexpr-config.h"

#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <setjmp.h>
#include <signal.h>
#include <ctype.h>

#include <time.h>
#include <sys/types.h>
#include <sys/time.h>
#if HAVE_SYS_RESOURCE_H
#include <sys/resource.h>
#endif

#include "gmp.h"

/* SunOS 4 and HPUX 9 don't define a canonical SIGSTKSZ, use a default. */
#ifndef SIGSTKSZ
#define SIGSTKSZ  4096
#endif


/* TIME(t,func): store in T the average time of one execution of FUNC, in the
   units returned by cputime() (milliseconds).  FUNC is first run once; while
   the measured time is below 100 the repetition count is doubled and FUNC is
   re-run that many times, and T is the last measurement divided by the
   repetition count.
   NOTE(review): FUNC is expanded twice and executed repeatedly, so it must be
   safe to repeat.  If cputime() never advances (e.g. the fallback that always
   returns 0) the while loop does not terminate.  */
#define TIME(t,func)                                        \
  do { int __t0, __times, __t, __tmp;                             \
    __times = 1;                                      \
    __t0 = cputime ();                                      \
    {func;}                                           \
    __tmp = cputime () - __t0;                                    \
    while (__tmp < 100)                                     \
      {                                                     \
      __times <<= 1;                                        \
      __t0 = cputime ();                                    \
      for (__t = 0; __t < __times; __t++)                   \
        {func;}                                       \
      __tmp = cputime () - __t0;                            \
      }                                                     \
    (t) = (double) __tmp / __times;                         \
  } while (0)

/* GMP version 1.x compatibility.  Selected when __GNU_MP_VERSION is undefined
   or less than 2: defines the mpz types in terms of MP_INT and maps the
   function names used in this file onto mpz_div, mpz_mod and mpz_div_2exp.
   mpz_sgn is supplied as a macro reading the `size' field directly.  */
#if ! (__GNU_MP_VERSION >= 2)
typedef MP_INT __mpz_struct;
typedef __mpz_struct mpz_t[1];
typedef __mpz_struct *mpz_ptr;
#define mpz_fdiv_q      mpz_div
#define mpz_fdiv_r      mpz_mod
#define mpz_tdiv_q_2exp mpz_div_2exp
#define mpz_sgn(Z) ((Z)->size < 0 ? -1 : (Z)->size > 0)
#endif

/* GMP version 2.0 compatibility.  Supplies mpz_swap as a macro that exchanges
   the two structs by value.  The arguments are not parenthesized in the
   expansion, so pass plain identifiers only.  */
#if ! (__GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1)
#define mpz_swap(a,b) \
  do { __mpz_struct __t; __t = *a; *a = *b; *b = __t;} while (0)
#endif

/* Jump buffer used for error recovery.  main() arms it with setjmp; the
   parser and the evaluator store a message in `error' and longjmp to it.  */
jmp_buf errjmpbuf;

/* Operation codes for expression tree nodes.  LIT marks a leaf holding a
   number; NOP is used as the terminator of the function table `fns'.  */
enum op_t {NOP, LIT, NEG, NOT, PLUS, MINUS, MULT, DIV, MOD, REM, INVMOD, POW,
         AND, IOR, XOR, SLL, SRA, POPCNT, HAMDIST, GCD, LCM, SQRT, ROOT, FAC,
         LOG, LOG2, FERMAT, MERSENNE, FIBONACCI, RANDOM, NEXTPRIME, BINOM};

/* Type for the expression tree.  */
struct expr
{
  /* Which operation this node performs; selects the active union member.  */
  enum op_t op;
  union
  {
    /* Operand subtrees for every op other than LIT; rhs is NULL for the
       unary operations.  */
    struct {struct expr *lhs, *rhs;} ops;
    /* The number itself, used when op == LIT.  */
    mpz_t val;
  } operands;
};

/* Pointer to an expression tree node.  */
typedef struct expr *expr_t;

/* Signal handler installed by setup_error_handler.  */
void cleanup_and_exit __GMP_PROTO ((int));

/* Parser: each function takes the current position in the input string and
   returns the position following what it consumed.  */
char *skipspace __GMP_PROTO ((char *));
void makeexp __GMP_PROTO ((expr_t *, enum op_t, expr_t, expr_t));
void free_expr __GMP_PROTO ((expr_t));
char *expr __GMP_PROTO ((char *, expr_t *));
char *term __GMP_PROTO ((char *, expr_t *));
char *power __GMP_PROTO ((char *, expr_t *));
char *factor __GMP_PROTO ((char *, expr_t *));
int match __GMP_PROTO ((char *, char *));
int matchp __GMP_PROTO ((char *, char *));
int cputime __GMP_PROTO ((void));

/* Evaluator.  */
void mpz_eval_expr __GMP_PROTO ((mpz_ptr, expr_t));
void mpz_eval_mod_expr __GMP_PROTO ((mpz_ptr, expr_t, mpz_ptr));

/* Message for the most recent error; set just before longjmp (errjmpbuf).  */
char *error;
/* Print the result (cleared by -noprint).  */
int flag_print = 1;
/* Print timing information (set by -t).  */
int print_timing = 0;
/* Set by -html.  */
int flag_html = 0;
/* Set by -wml.  */
int flag_wml = 0;
/* Break the printed result into 60-digit chunks (set by -split).  */
int flag_splitup_output = 0;
/* Markup written before each output newline: "", "<br>" or "<br/>".  */
char *newline = "";
/* Random state used by the `random' function; seeded in main().  */
gmp_randstate_t rstate;



/* cputime() returns user CPU time measured in milliseconds.  Unless the
   configuration already supplies it (HAVE_CPUTIME), it is built on
   getrusage when available, else on clock, and otherwise always reports 0.  */
#if ! HAVE_CPUTIME
int
cputime (void)
{
#if HAVE_GETRUSAGE
  struct rusage usage;

  getrusage (0, &usage);
  return usage.ru_utime.tv_sec * 1000 + usage.ru_utime.tv_usec / 1000;
#elif HAVE_CLOCK
  /* Pick the order of multiply and divide according to the clock
     resolution.  */
  if (CLOCKS_PER_SEC >= 100000)
    return clock () / (CLOCKS_PER_SEC / 1000);
  return clock () * 1000 / CLOCKS_PER_SEC;
#else
  return 0;
#endif
}
#endif


/* Compare XP, the address of a local variable in the caller's frame, with
   the address of a local of this frame.  Returns nonzero when the local of
   this frame has the lower address.  */
int
stack_downwards_helper (char *xp)
{
  char  callee_local;
  char *here = &callee_local;

  return here < xp;
}

/* Returns nonzero when a callee's locals sit at lower addresses than the
   caller's, as judged by stack_downwards_helper.  */
int
stack_downwards_p (void)
{
  char  caller_local;
  char *probe = &caller_local;

  return stack_downwards_helper (probe);
}


/* Install cleanup_and_exit as the handler for SIGILL, SIGSEGV, SIGBUS (where
   defined), SIGFPE and SIGABRT, and for SIGXCPU when LIMIT_RESOURCE_USAGE is
   defined.  Where the configuration allows, the handler is run on a
   separately allocated signal stack.  With LIMIT_RESOURCE_USAGE defined,
   core size, CPU time, data size and stack size limits are applied too.  */
void
setup_error_handler (void)
{
#if HAVE_SIGACTION
  struct sigaction act;
  act.sa_handler = cleanup_and_exit;
  sigemptyset (&(act.sa_mask));
#define SIGNAL(sig)  sigaction (sig, &act, NULL)
#else
  /* Dummy object so that the sa_flags assignments below still compile.  */
  struct { int sa_flags; } act;
#define SIGNAL(sig)  signal (sig, cleanup_and_exit)
#endif
  act.sa_flags = 0;

  /* Set up a stack for signal handling.  A typical cause of error is stack
     overflow, and in such situation a signal can not be delivered on the
     overflown stack.  */
#if HAVE_SIGALTSTACK
  {
    /* AIX uses stack_t, MacOS uses struct sigaltstack, various other
       systems have both. */
#if HAVE_STACK_T
    stack_t s;
#else
    struct sigaltstack s;
#endif
    s.ss_sp = malloc (SIGSTKSZ);
    s.ss_size = SIGSTKSZ;
    s.ss_flags = 0;
    if (s.ss_sp == NULL)
      {
        /* Without memory for the alternate stack, keep delivering signals
           on the normal stack rather than registering a null stack.  */
        perror ("malloc");
      }
    else
      {
        if (sigaltstack (&s, NULL) != 0)
          perror ("sigaltstack");
        act.sa_flags = SA_ONSTACK;
      }
  }
#elif HAVE_SIGSTACK
  {
    struct sigstack s;
    s.ss_sp = malloc (SIGSTKSZ);
    if (s.ss_sp == NULL)
      {
        perror ("malloc");
      }
    else
      {
        /* sigstack wants the initial stack pointer, which is the high end
           of the block when the stack grows downwards.  */
        if (stack_downwards_p ())
          s.ss_sp += SIGSTKSZ;
        s.ss_onstack = 0;
        if (sigstack (&s, NULL) != 0)
          perror ("sigstack");
        act.sa_flags = SA_ONSTACK;
      }
  }
#endif

#ifdef LIMIT_RESOURCE_USAGE
  {
    struct rlimit limit;

    /* No core dumps.  */
    limit.rlim_cur = limit.rlim_max = 0;
    setrlimit (RLIMIT_CORE, &limit);

    /* CPU seconds: soft limit 3 (raises SIGXCPU), hard limit 4.  */
    limit.rlim_cur = 3;
    limit.rlim_max = 4;
    setrlimit (RLIMIT_CPU, &limit);

    limit.rlim_cur = limit.rlim_max = 16 * 1024 * 1024;
    setrlimit (RLIMIT_DATA, &limit);

    /* Only the soft stack limit is changed; the hard limit is kept.  */
    getrlimit (RLIMIT_STACK, &limit);
    limit.rlim_cur = 4 * 1024 * 1024;
    setrlimit (RLIMIT_STACK, &limit);

    SIGNAL (SIGXCPU);
  }
#endif /* LIMIT_RESOURCE_USAGE */

  SIGNAL (SIGILL);
  SIGNAL (SIGSEGV);
#ifdef SIGBUS /* not in mingw */
  SIGNAL (SIGBUS);
#endif
  SIGNAL (SIGFPE);
  SIGNAL (SIGABRT);
}

int
main (int argc, char **argv)
{
  struct expr *e;
  int i;
  mpz_t r;
  int errcode = 0;
  char *str;
  int base = 10;

  setup_error_handler ();

  gmp_randinit (rstate, GMP_RAND_ALG_LC, 128);

  {
#if HAVE_GETTIMEOFDAY
    struct timeval tv;
    gettimeofday (&tv, NULL);
    gmp_randseed_ui (rstate, tv.tv_sec + tv.tv_usec);
#else
    time_t t;
    time (&t);
    gmp_randseed_ui (rstate, t);
#endif
  }

  mpz_init (r);

  while (argc > 1 && argv[1][0] == '-')
    {
      char *arg = argv[1];

      if (arg[1] >= '0' && arg[1] <= '9')
      break;

      if (arg[1] == 't')
      print_timing = 1;
      else if (arg[1] == 'b' && arg[2] >= '0' && arg[2] <= '9')
      {
        base = atoi (arg + 2);
        if (base < 2 || base > 36)
          {
            fprintf (stderr, "error: invalid output base\n");
            exit (-1);
          }
      }
      else if (arg[1] == 'b' && arg[2] == 0)
      base = 2;
      else if (arg[1] == 'x' && arg[2] == 0)
      base = 16;
      else if (arg[1] == 'X' && arg[2] == 0)
      base = -16;
      else if (arg[1] == 'o' && arg[2] == 0)
      base = 8;
      else if (arg[1] == 'd' && arg[2] == 0)
      base = 10;
      else if (strcmp (arg, "-html") == 0)
      {
        flag_html = 1;
        newline = "<br>";
      }
      else if (strcmp (arg, "-wml") == 0)
      {
        flag_wml = 1;
        newline = "<br/>";
      }
      else if (strcmp (arg, "-split") == 0)
      {
        flag_splitup_output = 1;
      }
      else if (strcmp (arg, "-noprint") == 0)
      {
        flag_print = 0;
      }
      else
      {
        fprintf (stderr, "error: unknown option `%s'\n", arg);
        exit (-1);
      }
      argv++;
      argc--;
    }

  for (i = 1; i < argc; i++)
    {
      int s;
      int jmpval;

      /* Set up error handler for parsing expression.  */
      jmpval = setjmp (errjmpbuf);
      if (jmpval != 0)
      {
        fprintf (stderr, "error: %s%s\n", error, newline);
        fprintf (stderr, "       %s%s\n", argv[i], newline);
        if (! flag_html)
          {
            /* ??? Dunno how to align expression position with arrow in
             HTML ??? */
            fprintf (stderr, "       ");
            for (s = jmpval - (long) argv[i]; --s >= 0; )
            putc (' ', stderr);
            fprintf (stderr, "^\n");
          }

        errcode |= 1;
        continue;
      }

      str = expr (argv[i], &e);

      if (str[0] != 0)
      {
        fprintf (stderr,
               "error: garbage where end of expression expected%s\n",
               newline);
        fprintf (stderr, "       %s%s\n", argv[i], newline);
        if (! flag_html)
          {
            /* ??? Dunno how to align expression position with arrow in
             HTML ??? */
            fprintf (stderr, "        ");
            for (s = str - argv[i]; --s; )
            putc (' ', stderr);
            fprintf (stderr, "^\n");
          }

        errcode |= 1;
        free_expr (e);
        continue;
      }

      /* Set up error handler for evaluating expression.  */
      if (setjmp (errjmpbuf))
      {
        fprintf (stderr, "error: %s%s\n", error, newline);
        fprintf (stderr, "       %s%s\n", argv[i], newline);
        if (! flag_html)
          {
            /* ??? Dunno how to align expression position with arrow in
             HTML ??? */
            fprintf (stderr, "       ");
            for (s = str - argv[i]; --s >= 0; )
            putc (' ', stderr);
            fprintf (stderr, "^\n");
          }

        errcode |= 2;
        continue;
      }

      if (print_timing)
      {
        double t;
        TIME (t, mpz_eval_expr (r, e));
        printf ("computation took %.2f ms%s\n", t, newline);
      }
      else
      mpz_eval_expr (r, e);

      if (flag_print)
      {
        size_t out_len;
        char *tmp, *s;

        out_len = mpz_sizeinbase (r, base >= 0 ? base : -base) + 2;
        tmp = malloc (out_len);

        if (print_timing)
          {
            double t;
            printf ("output conversion ");
            TIME (t, mpz_get_str (tmp, base, r));
            printf ("took %.2f ms%s\n", t, newline);
          }
        else
          mpz_get_str (tmp, base, r);

        out_len = strlen (tmp);
        if (flag_splitup_output)
          {
            for (s = tmp; out_len > 60; s += 60)
            {
              fwrite (s, 1, 60, stdout);
              printf ("%s\n", newline);
              out_len -= 60;
            }

            fwrite (s, 1, out_len, stdout);
          }
        else
          {
            fwrite (tmp, 1, out_len, stdout);
          }

        free (tmp);
        printf ("%s\n", newline);
      }
      else
      {
        printf ("result is approximately %ld digits%s\n",
              (long) mpz_sizeinbase (r, 10), newline);
      }

      free_expr (e);
    }

  exit (errcode);
}

/* Parse an additive expression starting at STR: an optional unary `+', `-'
   or `~' applied to the first term, followed by any number of terms joined
   by `+', `-', `plus' or `minus'.  The tree is stored in *E and the position
   of the first unconsumed character is returned.  */
char *
expr (char *str, expr_t *e)
{
  expr_t rhs;
  enum op_t unary = NOP;

  str = skipspace (str);
  switch (str[0])
    {
    case '+':
      str++;
      break;
    case '-':
      unary = NEG;
      str++;
      break;
    case '~':
      unary = NOT;
      str++;
      break;
    default:
      break;
    }

  str = term (str, e);
  if (unary != NOP)
    makeexp (e, unary, *e, NULL);

  for (;;)
    {
      enum op_t binop;
      int oplen;

      str = skipspace (str);
      if (str[0] == '+')
        {
          binop = PLUS;
          oplen = 1;
        }
      else if (str[0] == '-')
        {
          binop = MINUS;
          oplen = 1;
        }
      else if (str[0] == 'p' && match ("plus", str))
        {
          binop = PLUS;
          oplen = 4;
        }
      else if (str[0] == 'm' && match ("minus", str))
        {
          binop = MINUS;
          oplen = 5;
        }
      else
        return str;

      /* Left associative: the tree built so far becomes the left operand.  */
      str = term (str + oplen, &rhs);
      makeexp (e, binop, *e, rhs);
    }
}

/* Parse a multiplicative expression starting at STR: powers joined by `*',
   `/', `%' or one of the keyword operators `mul', `times', `div', `thru',
   `through', `mod', `rem' and `invmod'.  The tree is stored in *E and the
   position of the first unconsumed character is returned.  A word that is
   not one of these operators is left for the caller to deal with.  */
char *
term (char *str, expr_t *e)
{
  expr_t e2;

  str = power (str, e);
  for (;;)
    {
      str = skipspace (str);
      switch (str[0])
        {
        case 'm':
          if (match ("mul", str))
            {
              str = power (str + 3, &e2);
              makeexp (e, MULT, *e, e2);
              break;
            }
          if (match ("mod", str))
            {
              str = power (str + 3, &e2);
              makeexp (e, MOD, *e, e2);
              break;
            }
          return str;
        case 'd':
          if (match ("div", str))
            {
              str = power (str + 3, &e2);
              makeexp (e, DIV, *e, e2);
              break;
            }
          return str;
        case 'r':
          if (match ("rem", str))
            {
              str = power (str + 3, &e2);
              makeexp (e, REM, *e, e2);
              break;
            }
          return str;
        case 'i':
          /* Same version guard as the `invmod' entry of the function table
             and the INVMOD case of the evaluator.  */
#if __GNU_MP_VERSION >= 2
          if (match ("invmod", str))
            {
              str = power (str + 6, &e2);
              /* Modular inverse, not remainder.  */
              makeexp (e, INVMOD, *e, e2);
              break;
            }
#endif
          return str;
        case 't':
          if (match ("times", str))
            {
              str = power (str + 5, &e2);
              makeexp (e, MULT, *e, e2);
              break;
            }
          if (match ("thru", str))
            {
              str = power (str + 4, &e2);
              makeexp (e, DIV, *e, e2);
              break;
            }
          if (match ("through", str))
            {
              str = power (str + 7, &e2);
              makeexp (e, DIV, *e, e2);
              break;
            }
          return str;
        case '*':
          str = power (str + 1, &e2);
          makeexp (e, MULT, *e, e2);
          break;
        case '/':
          str = power (str + 1, &e2);
          makeexp (e, DIV, *e, e2);
          break;
        case '%':
          str = power (str + 1, &e2);
          makeexp (e, MOD, *e, e2);
          break;
        default:
          return str;
        }
    }
}

/* Parse a factor followed by any number of `!' (factorial) suffixes and an
   optional `^' exponent.  The exponent is parsed by a recursive call, which
   makes `^' right associative.  The tree is stored in *E and the position
   of the first unconsumed character is returned.  */
char *
power (char *str, expr_t *e)
{
  expr_t exponent;

  str = factor (str, e);

  /* The `!' must follow the factor directly; no space is skipped here.  */
  for (; str[0] == '!'; str++)
    makeexp (e, FAC, *e, NULL);

  str = skipspace (str);
  if (str[0] != '^')
    return str;

  str = power (str + 1, &exponent);
  makeexp (e, POW, *e, exponent);
  return str;
}

int
match (char *s, char *str)
{
  char *ostr = str;
  int i;

  for (i = 0; s[i] != 0; i++)
    {
      if (str[i] != s[i])
      return 0;
    }
  str = skipspace (str + i);
  return str - ostr;
}

/* Check whether STR begins with the word S followed, after optional spaces,
   by `('.  Returns 0 if it does not; otherwise returns the number of
   characters up to and including the `('.  */
int
matchp (char *s, char *str)
{
  char *p = str;

  while (*s != 0)
    if (*p++ != *s++)
      return 0;

  p = skipspace (p);
  if (p[0] != '(')
    return 0;
  return p - str + 1;
}

/* One entry of the table of functions that can be called as NAME(ARGS).  */
struct functions
{
  char *spelling;  /* name as written in an expression */
  enum op_t op;    /* tree node operation the call is turned into */
  int arity; /* 1 or 2 means real arity; 0 means arbitrary.  */
};

/* The functions recognized by factor().  Entries that depend on the GMP
   version are compiled in conditionally.  The entry whose op is NOP marks
   the end of the table.  */
struct functions fns[] =
{
  {"sqrt", SQRT, 1},
#if __GNU_MP_VERSION >= 2
  {"root", ROOT, 2},
  {"popc", POPCNT, 1},
  {"hamdist", HAMDIST, 2},
#endif
  {"gcd", GCD, 0},
#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
  {"lcm", LCM, 0},
#endif
  {"and", AND, 0},
  {"ior", IOR, 0},
#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
  {"xor", XOR, 0},
#endif
  {"plus", PLUS, 0},
  {"pow", POW, 2},
  {"minus", MINUS, 2},
  {"mul", MULT, 0},
  {"div", DIV, 2},
  {"mod", MOD, 2},
  {"rem", REM, 2},
#if __GNU_MP_VERSION >= 2
  {"invmod", INVMOD, 2},
#endif
  {"log", LOG, 2},
  {"log2", LOG2, 1},
  {"F", FERMAT, 1},
  {"M", MERSENNE, 1},
  {"fib", FIBONACCI, 1},
  {"Fib", FIBONACCI, 1},
  {"random", RANDOM, 1},
  {"nextprime", NEXTPRIME, 1},
  {"binom", BINOM, 2},
  {"binomial", BINOM, 2},
  {"", NOP, 0}
};

/* Parse a primary expression starting at STR: a function call NAME(ARGS)
   for a NAME listed in `fns', a parenthesized expression, or a number.  The
   tree is stored in *E and the position of the first unconsumed character
   is returned.

   On a syntax error a message is stored in `error' and control is passed to
   errjmpbuf with longjmp; the value handed to longjmp is the error position
   converted to int.  */
char *
factor (char *str, expr_t *e)
{
  expr_t e1, e2;

  str = skipspace (str);

  /* The <ctype.h> functions require a value representable as unsigned
     char.  */
  if (isalpha ((unsigned char) str[0]))
    {
      int i;
      int cnt;

      /* Functions of exactly one argument.  */
      for (i = 0; fns[i].op != NOP; i++)
        {
          if (fns[i].arity == 1)
            {
              cnt = matchp (fns[i].spelling, str);
              if (cnt != 0)
                {
                  str = expr (str + cnt, &e1);
                  str = skipspace (str);
                  if (str[0] != ')')
                    {
                      error = "expected `)'";
                      longjmp (errjmpbuf, (int) (long) str);
                    }
                  makeexp (e, fns[i].op, e1, NULL);
                  return str + 1;
                }
            }
        }

      /* Functions of two arguments, or of two or more (arity 0).  */
      for (i = 0; fns[i].op != NOP; i++)
        {
          if (fns[i].arity != 1)
            {
              cnt = matchp (fns[i].spelling, str);
              if (cnt != 0)
                {
                  str = expr (str + cnt, &e1);
                  str = skipspace (str);

                  if (str[0] != ',')
                    {
                      error = "expected `,' and another operand";
                      longjmp (errjmpbuf, (int) (long) str);
                    }

                  str = skipspace (str + 1);
                  str = expr (str, &e2);
                  str = skipspace (str);

                  if (fns[i].arity == 0)
                    {
                      /* Fold further arguments to the left:
                         f(a,b,c) becomes f(f(a,b),c).  */
                      while (str[0] == ',')
                        {
                          makeexp (&e1, fns[i].op, e1, e2);
                          str = skipspace (str + 1);
                          str = expr (str, &e2);
                          str = skipspace (str);
                        }
                    }

                  if (str[0] != ')')
                    {
                      error = "expected `)'";
                      longjmp (errjmpbuf, (int) (long) str);
                    }

                  makeexp (e, fns[i].op, e1, e2);
                  return str + 1;
                }
            }
        }
    }

  if (str[0] == '(')
    {
      str = expr (str + 1, e);
      str = skipspace (str);
      if (str[0] != ')')
        {
          error = "expected `)'";
          longjmp (errjmpbuf, (int) (long) str);
        }
      str++;
    }
  else if (str[0] >= '0' && str[0] <= '9')
    {
      expr_t res;
      char *s, *sc;
      int bad;

      /* Letters are taken along with the digits so that prefixed forms
         such as 0x1f reach mpz_set_str, which is called with base 0.  */
      s = str;
      while (isalnum ((unsigned char) str[0]))
        str++;
      sc = malloc (str - s + 1);
      memcpy (sc, s, str - s);
      sc[str - s] = 0;

      res = malloc (sizeof (struct expr));
      res -> op = LIT;
      mpz_init (res->operands.val);

      /* mpz_set_str returns nonzero when the text is not a valid number;
         report that instead of going on with an unspecified value.  */
      bad = mpz_set_str (res->operands.val, sc, 0) != 0;
      free (sc);
      if (bad)
        {
          mpz_clear (res->operands.val);
          free (res);
          error = "invalid number";
          longjmp (errjmpbuf, (int) (long) s);
        }
      *e = res;
    }
  else
    {
      error = "operand expected";
      longjmp (errjmpbuf, (int) (long) str);
    }
  return str;
}

/* Return the position of the first character of STR that is not a space
   character (' ').  Other whitespace, such as tabs, is not skipped.  */
char *
skipspace (char *str)
{
  return str + strspn (str, " ");
}

/* Allocate a new tree node for operation OP with operands LHS and RHS (RHS
   is NULL for unary operations) and store a pointer to it in *R.  */
void
makeexp (expr_t *r, enum op_t op, expr_t lhs, expr_t rhs)
{
  expr_t node = malloc (sizeof *node);

  node->op = op;
  node->operands.ops.rhs = rhs;
  node->operands.ops.lhs = lhs;
  *r = node;
}

/* Free the memory used by expression E: the numbers held by its literals,
   all of its subtrees, and the nodes themselves.  A NULL E is accepted and
   ignored, which also covers the absent right operand of unary nodes.  */
void
free_expr (expr_t e)
{
  if (e == NULL)
    return;

  if (e->op != LIT)
    {
      free_expr (e->operands.ops.lhs);
      free_expr (e->operands.ops.rhs);
    }
  else
    {
      mpz_clear (e->operands.val);
    }

  /* Nodes are allocated with malloc by makeexp and factor.  */
  free (e);
}

/* Evaluate the expression E and put the result in R.

   The tree E is not modified.  When an operation cannot be carried out, a
   message is stored in `error' and control is passed to errjmpbuf with
   longjmp (value 1).  The temporaries of the failing call are released
   first, but those of the enclosing recursive calls are not.  An operation
   code that is not handled (including ones compiled out for old GMP
   versions) calls abort.  */
void
mpz_eval_expr (mpz_ptr r, expr_t e)
{
  mpz_t lhs, rhs;

  switch (e->op)
    {
    case LIT:
      mpz_set (r, e->operands.val);
      return;
    case PLUS:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_add (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
    case MINUS:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_sub (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
    case MULT:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_mul (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
    case DIV:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_fdiv_q (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
    case MOD:
      /* Evaluate the left operand modulo |rhs| all the way down, so that
         intermediate values stay small.  */
      mpz_init (rhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_abs (rhs, rhs);
      mpz_eval_mod_expr (r, e->operands.ops.lhs, rhs);
      mpz_clear (rhs);
      return;
    case REM:
      /* Check if lhs operand is POW expression and optimize for that case.  */
      if (e->operands.ops.lhs->op == POW)
        {
          mpz_t powlhs, powrhs;
          mpz_init (powlhs);
          mpz_init (powrhs);
          mpz_init (rhs);
          mpz_eval_expr (powlhs, e->operands.ops.lhs->operands.ops.lhs);
          mpz_eval_expr (powrhs, e->operands.ops.lhs->operands.ops.rhs);
          mpz_eval_expr (rhs, e->operands.ops.rhs);
          mpz_powm (r, powlhs, powrhs, rhs);
          if (mpz_cmp_si (rhs, 0L) < 0)
            mpz_neg (r, r);
          mpz_clear (powlhs);
          mpz_clear (powrhs);
          mpz_clear (rhs);
          return;
        }

      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_fdiv_r (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
#if __GNU_MP_VERSION >= 2
    case INVMOD:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_invert (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
#endif
    case POW:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      if (mpz_cmpabs_ui (lhs, 1) == 0)
        {
          /* For 1^rhs and (-1)^rhs only the parity of rhs matters, so
             evaluate rhs modulo 2, which stays cheap even when rhs itself
             would be huge.  This must not be done for 0^rhs: there the
             result depends on whether rhs is zero, which (rhs mod 2) does
             not tell.  */
          mpz_t two;
          mpz_init_set_ui (two, 2L);
          mpz_eval_mod_expr (rhs, e->operands.ops.rhs, two);
          mpz_clear (two);
        }
      else
        mpz_eval_expr (rhs, e->operands.ops.rhs);

      if (mpz_cmp_si (rhs, 0L) == 0)
        /* x^0 is 1 */
        mpz_set_ui (r, 1L);
      else if (mpz_cmp_si (lhs, 0L) == 0)
        /* 0^y (where y != 0) is 0 */
        mpz_set_ui (r, 0L);
      else if (mpz_cmp_ui (lhs, 1L) == 0)
        /* 1^y is 1 */
        mpz_set_ui (r, 1L);
      else if (mpz_cmp_si (lhs, -1L) == 0)
        /* (-1)^y just depends on whether y is even or odd */
        mpz_set_si (r, (mpz_get_ui (rhs) & 1) ? -1L : 1L);
      else if (mpz_cmp_si (rhs, 0L) < 0)
        /* x^(-n) is 0 */
        mpz_set_ui (r, 0L);
      else
        {
          unsigned long int cnt;
          unsigned long int y;
          /* error if exponent does not fit into an unsigned long int.  */
          if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
            goto pow_err;

          y = mpz_get_ui (rhs);
          /* x^y == (x/(2^c))^y * 2^(c*y) */
#if __GNU_MP_VERSION >= 2
          cnt = mpz_scan1 (lhs, 0);
#else
          cnt = 0;
#endif
          if (cnt != 0)
            {
              /* error if the shift count y * cnt overflows.  */
              if (y * cnt / cnt != y)
                goto pow_err;
              mpz_tdiv_q_2exp (lhs, lhs, cnt);
              mpz_pow_ui (r, lhs, y);
              mpz_mul_2exp (r, r, y * cnt);
            }
          else
            mpz_pow_ui (r, lhs, y);
        }
      mpz_clear (lhs); mpz_clear (rhs);
      return;
    pow_err:
      error = "result of `pow' operator too large";
      mpz_clear (lhs); mpz_clear (rhs);
      longjmp (errjmpbuf, 1);
    case GCD:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_gcd (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
    case LCM:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_lcm (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
#endif
    case AND:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_and (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
    case IOR:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_ior (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
    case XOR:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      mpz_xor (r, lhs, rhs);
      mpz_clear (lhs); mpz_clear (rhs);
      return;
#endif
    case NEG:
      mpz_eval_expr (r, e->operands.ops.lhs);
      mpz_neg (r, r);
      return;
    case NOT:
      mpz_eval_expr (r, e->operands.ops.lhs);
      mpz_com (r, r);
      return;
    case SQRT:
      mpz_init (lhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      if (mpz_sgn (lhs) < 0)
        {
          error = "cannot take square root of negative numbers";
          mpz_clear (lhs);
          longjmp (errjmpbuf, 1);
        }
      mpz_sqrt (r, lhs);
      mpz_clear (lhs);
      return;
#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
    case ROOT:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      if (mpz_sgn (rhs) <= 0)
        {
          error = "cannot take non-positive root orders";
          mpz_clear (lhs); mpz_clear (rhs);
          longjmp (errjmpbuf, 1);
        }
      if (mpz_sgn (lhs) < 0 && (mpz_get_ui (rhs) & 1) == 0)
        {
          error = "cannot take even root orders of negative numbers";
          mpz_clear (lhs); mpz_clear (rhs);
          longjmp (errjmpbuf, 1);
        }

      {
        unsigned long int nth = mpz_get_ui (rhs);
        if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
          {
            /* If we are asked to take an awfully large root order, cheat and
               ask for the largest order we can pass to mpz_root.  This saves
               some error prone special cases.  */
            nth = ~(unsigned long int) 0;
          }
        mpz_root (r, lhs, nth);
      }
      mpz_clear (lhs); mpz_clear (rhs);
      return;
#endif
    case FAC:
      mpz_eval_expr (r, e->operands.ops.lhs);
      /* Only operands that fit in a single limb are attempted.  */
      if (mpz_size (r) > 1)
        {
          error = "result of `!' operator too large";
          longjmp (errjmpbuf, 1);
        }
      mpz_fac_ui (r, mpz_get_ui (r));
      return;
#if __GNU_MP_VERSION >= 2
    case POPCNT:
      mpz_eval_expr (r, e->operands.ops.lhs);
      { long int cnt;
        cnt = mpz_popcount (r);
        mpz_set_si (r, cnt);
      }
      return;
    case HAMDIST:
      { long int cnt;
        mpz_init (lhs); mpz_init (rhs);
        mpz_eval_expr (lhs, e->operands.ops.lhs);
        mpz_eval_expr (rhs, e->operands.ops.rhs);
        cnt = mpz_hamdist (lhs, rhs);
        mpz_clear (lhs); mpz_clear (rhs);
        mpz_set_si (r, cnt);
      }
      return;
#endif
    case LOG2:
      mpz_eval_expr (r, e->operands.ops.lhs);
      { unsigned long int cnt;
        if (mpz_sgn (r) <= 0)
          {
            error = "logarithm of non-positive number";
            longjmp (errjmpbuf, 1);
          }
        cnt = mpz_sizeinbase (r, 2);
        mpz_set_ui (r, cnt - 1);
      }
      return;
    case LOG:
      { unsigned long int cnt;
        mpz_init (lhs); mpz_init (rhs);
        mpz_eval_expr (lhs, e->operands.ops.lhs);
        mpz_eval_expr (rhs, e->operands.ops.rhs);
        if (mpz_sgn (lhs) <= 0)
          {
            error = "logarithm of non-positive number";
            mpz_clear (lhs); mpz_clear (rhs);
            longjmp (errjmpbuf, 1);
          }
        /* mpz_sizeinbase has no meaning for bases below 2; this also
           rejects negative bases, which mpz_get_ui would otherwise turn
           into their absolute value.  */
        if (mpz_cmp_ui (rhs, 2) < 0)
          {
            error = "logarithm base too small";
            mpz_clear (lhs); mpz_clear (rhs);
            longjmp (errjmpbuf, 1);
          }
        if (mpz_cmp_ui (rhs, 256) >= 0)
          {
            error = "logarithm base too large";
            mpz_clear (lhs); mpz_clear (rhs);
            longjmp (errjmpbuf, 1);
          }
        cnt = mpz_sizeinbase (lhs, mpz_get_ui (rhs));
        mpz_set_ui (r, cnt - 1);
        mpz_clear (lhs); mpz_clear (rhs);
      }
      return;
    case FERMAT:
      {
        unsigned long int t;
        mpz_init (lhs);
        mpz_eval_expr (lhs, e->operands.ops.lhs);
        /* F(n) = 2^(2^n) + 1, and 2^n has to fit in an unsigned long.
           Shifting by the width of the type or more is undefined, so the
           index is checked before the shift.  The width is computed
           assuming 8-bit bytes.  */
        if (mpz_cmpabs_ui (lhs, 8 * sizeof (unsigned long int)) >= 0)
          {
            error = "too large Fermat number index";
            mpz_clear (lhs);
            longjmp (errjmpbuf, 1);
          }
        t = (unsigned long int) 1 << mpz_get_ui (lhs);
        mpz_set_ui (r, 1);
        mpz_mul_2exp (r, r, t);
        mpz_add_ui (r, r, 1);
        mpz_clear (lhs);
      }
      return;
    case MERSENNE:
      mpz_init (lhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0)
        {
          error = "too large Mersenne number index";
          mpz_clear (lhs);
          longjmp (errjmpbuf, 1);
        }
      mpz_set_ui (r, 1);
      mpz_mul_2exp (r, r, mpz_get_ui (lhs));
      mpz_sub_ui (r, r, 1);
      mpz_clear (lhs);
      return;
    case FIBONACCI:
      { mpz_t t;
        unsigned long int n, i;
        mpz_init (lhs);
        mpz_eval_expr (lhs, e->operands.ops.lhs);
        if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
          {
            error = "Fibonacci index out of range";
            mpz_clear (lhs);
            longjmp (errjmpbuf, 1);
          }
        n = mpz_get_ui (lhs);
        mpz_clear (lhs);

#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
        mpz_fib_ui (r, n);
#else
        /* No mpz_fib_ui in this GMP version: iterate the recurrence.  */
        mpz_init_set_ui (t, 1);
        mpz_set_ui (r, 1);

        if (n <= 2)
          mpz_set_ui (r, 1);
        else
          {
            for (i = 3; i <= n; i++)
              {
                mpz_add (t, t, r);
                mpz_swap (t, r);
              }
          }
        mpz_clear (t);
#endif
      }
      return;
    case RANDOM:
      {
        unsigned long int n;
        mpz_init (lhs);
        mpz_eval_expr (lhs, e->operands.ops.lhs);
        if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
          {
            error = "random number size out of range";
            mpz_clear (lhs);
            longjmp (errjmpbuf, 1);
          }
        n = mpz_get_ui (lhs);
        mpz_clear (lhs);
        /* The operand is the size of the random number in bits.  */
        mpz_urandomb (r, rstate, n);
      }
      return;
    case NEXTPRIME:
      {
        mpz_eval_expr (r, e->operands.ops.lhs);
        mpz_nextprime (r, r);
      }
      return;
    case BINOM:
      mpz_init (lhs); mpz_init (rhs);
      mpz_eval_expr (lhs, e->operands.ops.lhs);
      mpz_eval_expr (rhs, e->operands.ops.rhs);
      {
        unsigned long int k;
        if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
          {
            error = "k too large in (n over k) expression";
            mpz_clear (lhs); mpz_clear (rhs);
            longjmp (errjmpbuf, 1);
          }
        k = mpz_get_ui (rhs);
        mpz_bin_ui (r, lhs, k);
      }
      mpz_clear (lhs); mpz_clear (rhs);
      return;
    default:
      abort ();
    }
}

/* Evaluate the expression E modulo MOD and put the result in R.

   For POW, PLUS, MINUS and MULT nodes the reduction is pushed down into the
   operands (the exponent of a POW is evaluated in full).  Every other kind
   of node is evaluated with mpz_eval_expr and reduced afterwards.  */
void
mpz_eval_mod_expr (mpz_ptr r, expr_t e, mpz_ptr mod)
{
  mpz_t a, b;
  enum op_t op = e->op;

  if (op != POW && op != PLUS && op != MINUS && op != MULT)
    {
      mpz_init (a);
      mpz_eval_expr (a, e);
      mpz_mod (r, a, mod);
      mpz_clear (a);
      return;
    }

  mpz_init (a);
  mpz_init (b);

  /* Left operand first, then the right one.  */
  mpz_eval_mod_expr (a, e->operands.ops.lhs, mod);
  if (op == POW)
    mpz_eval_expr (b, e->operands.ops.rhs);
  else
    mpz_eval_mod_expr (b, e->operands.ops.rhs, mod);

  if (op == POW)
    {
      mpz_powm (r, a, b, mod);
    }
  else if (op == MULT)
    {
      mpz_mul (r, a, b);
      mpz_mod (r, r, mod);
    }
  else
    {
      if (op == PLUS)
        mpz_add (r, a, b);
      else
        mpz_sub (r, a, b);

      /* A single correction step is applied after adding or subtracting
         the two reduced operands.  */
      if (mpz_cmp_si (r, 0L) < 0)
        mpz_add (r, r, mod);
      else if (mpz_cmp (r, mod) >= 0)
        mpz_sub (r, r, mod);
    }

  mpz_clear (a);
  mpz_clear (b);
}

/* Signal handler: print a message chosen from the signal number SIG on
   stdout and terminate the program with exit status -2.  Any signal other
   than SIGFPE (and SIGXCPU under LIMIT_RESOURCE_USAGE) is reported as
   running out of memory.  */
void
cleanup_and_exit (int sig)
{
#ifdef LIMIT_RESOURCE_USAGE
  if (sig == SIGXCPU)
    printf ("expression took too long to evaluate%s\n", newline);
  else
#endif
  if (sig == SIGFPE)
    printf ("divide by zero%s\n", newline);
  else
    printf ("expression required too much memory to evaluate%s\n", newline);

  exit (-2);
}

/* Generated by  Doxygen 1.6.0   Back to index */