/*
 * Copyright 2002 Massachusetts Institute of Technology
 *
 * Permission to use, copy, modify, and distribute this software and
 * its documentation for any purpose and without fee is hereby
 * granted, provided that both the above copyright notice and this
 * permission notice appear in all copies, that both the above
 * copyright notice and this permission notice appear in all
 * supporting documentation, and that the name of M.I.T. not be used
 * in advertising or publicity pertaining to distribution of the
 * software without specific, written prior permission.  M.I.T. makes
 * no representations about the suitability of this software for any
 * purpose.  It is provided "as is" without express or implied
 * warranty.
 * 
 * THIS SOFTWARE IS PROVIDED BY M.I.T. ``AS IS''.  M.I.T. DISCLAIMS
 * ALL EXPRESS OR IMPLIED WARRANTIES WITH REGARD TO THIS SOFTWARE,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. IN NO EVENT
 * SHALL M.I.T. BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
 * USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
 * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include <err.h>
#include <stdio.h>
#include <gmp.h>

/*
 * The constant -1.  main() initializes it, but nothing in this file reads
 * it: sprp() tests for "-1 (mod n)" by comparing against n - 1 instead.
 */
static mpz_t MINUS1;

/*
 * Strong probable-prime (single-base Miller-Rabin) test of n to `base`.
 *
 * Writes n - 1 = d * 2**s with d odd and reports n as a strong probable
 * prime to `base` when either
 *	base**d == 1 (mod n), or
 *	base**(d * 2**r) == -1 (mod n) for some 0 <= r < s.
 *
 * Returns 1 if n passes (n is a base-SPRP), 0 otherwise.
 *
 * Values of n below 2 are never SPRPs and return 0 immediately: n == 1
 * would make n - 1 == 0 (divisible by every power of two, so no finite s
 * exists) and n == 0 would be used as a zero modulus.
 */
int
sprp(mpz_t n, mpz_t base)
{
	mpz_t d, nminus1, x;
	unsigned long r, s;
	int rv = 0;

	if (mpz_cmp_ui(n, 2) < 0)
		return (0);

	mpz_init(nminus1);
	mpz_sub_ui(nminus1, n, 1);

	/*
	 * s = number of trailing zero bits of n - 1 (n - 1 >= 1 here, so a
	 * set bit always exists); d = (n - 1) / 2**s is then odd.
	 */
	s = mpz_scan1(nminus1, 0);
	mpz_init(d);
	mpz_fdiv_q_2exp(d, nminus1, s);

	/* x = base**d mod n */
	mpz_init(x);
	mpz_powm(x, base, d, n);
	if (mpz_cmp_ui(x, 1) == 0) {
#ifdef DEBUG
		printf("a**d = 1 (mod n)\n");
#endif
		rv = 1;
		goto out;
	}
#ifdef DEBUG
	printf("a**d != 1 (mod n)\n");
	mpz_out_str(NULL, 10, base);
	printf("**");
	mpz_out_str(NULL, 10, d);
	printf(" = ");
	mpz_out_str(NULL, 10, x);
	printf(" (mod ");
	mpz_out_str(NULL, 10, n);
	printf(")\n");
#endif /* DEBUG */

	/*
	 * Invariant at the top of each iteration: x == base**(d * 2**r) mod n.
	 * Squaring once per step avoids recomputing (base**d)**(2**r) from
	 * scratch, which would cost r squarings on step r.
	 */
	for (r = 0; r < s; r++) {
		if (mpz_cmp(x, nminus1) == 0) {
#ifdef DEBUG
			printf("(a**d)**(2**%lu) = -1 (mod n)\n", r);
#endif
			rv = 1;
			goto out;
		}
#ifdef DEBUG
		printf("a**(d*2**%lu) != -1 (mod n)\n", r);
		printf("(a**d)**(2**%lu) = ", r);
		mpz_out_str(NULL, 10, x);
		printf(" (mod n)\n");
#endif /* DEBUG */
		mpz_mul(x, x, x);
		mpz_mod(x, x, n);	/* keep x in [0, n) for the -1 comparison */
	}

out:
	mpz_clear(x);
	mpz_clear(d);
	mpz_clear(nminus1);
	return (rv);
}

/*
 * Usage: sprp n base
 *
 * Parses n and base as decimal integers, runs the strong probable-prime
 * test, and prints the outcome on stdout.  Exit status is 0 when n is a
 * base-SPRP and 1 when it is not; invalid arguments also exit with 1 via
 * errx().
 */
int
main(int argc, char **argv)
{
	mpz_t n, base;
	int rv;

	if (argc != 3) {
		errx(1, "incorrect argument count\nusage:\n\tsprp n base");
	}

	mpz_init_set_si(MINUS1, -1);

	if (mpz_init_set_str(n, argv[1], 10) != 0) {
		errx(1, "mpz_init_set_str(n) error");
	}

	/*
	 * sprp() decomposes n - 1 = d * 2**s and reduces modulo n; neither
	 * is meaningful for n < 2 (n == 1 gives n - 1 == 0, n == 0 is a zero
	 * modulus), so reject such input up front.
	 */
	if (mpz_cmp_ui(n, 2) < 0) {
		errx(1, "n must be at least 2");
	}

	if (mpz_init_set_str(base, argv[2], 10) != 0) {
		errx(1, "mpz_init_set_str(base) error");
	}

	rv = sprp(n, base);
	printf("%s is%s a %s-SPRP\n", argv[1], (rv ? "" : " not"), argv[2]);

	mpz_clear(base);
	mpz_clear(n);
	mpz_clear(MINUS1);
	return (!rv);
}
