 * MPI ping-pong benchmark
 *
 * usage: mpirun -np 2 mpibench [nbytes]
 *
 * Define the CHECK macro to check data integrity.
 */

#include <float.h>
#include <limits.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <mpi.h>

/* number of timed ping-pong iterations; also the size of the sample array */
#define NLOOPS		1000
/*
 * byte boundary the message buffer is rounded up to; the rounding is done
 * with a bit mask, so this value must be a power of two
 */
#define ALIGN		4096

/*
 * local functions
 */
static int		compare();	/* comparison callback handed to qsort() */

/*
 * static variables
 */
/*
 * One timing sample per timed iteration: half of the measured round-trip
 * time with the timer overhead subtracted.  Sorted after the run so that
 * the minimum, median and maximum can be read off directly.
 */
static double		vtmp[NLOOPS];


/*
 *	main
 *
 *	Function:	- ping-pong benchmark between exactly two processes
 *			- rank 0 sends a message to rank 1 and waits for the
 *			  echo, NLOOPS times, then prints the min/median/max
 *			  of half the round-trip time (and the matching
 *			  bandwidth when the message is not empty)
 *			- rank 1 echoes every message back to rank 0
 *	Accepts:	- argc/argv; optional argv[1] is the message size in
 *			  bytes (default 0, negative values are treated as 0)
 *	Returns:	- 0
 */
int
main(int argc, char *argv[])
{
	int		i;
#ifdef CHECK
	int		j;
#endif
	double		start, stop;
	double		ovrhd, min, med, max;
	long		req;
	int		nbytes = 0;
	int		rank, size;
	MPI_Status	status;
	char		*raw;		/* pointer returned by malloc() */
	char		*buf;		/* ALIGN-aligned start inside raw */

	setvbuf(stdout, NULL, _IOLBF, 0);

	MPI_Init(&argc, &argv);
	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
	MPI_Comm_size(MPI_COMM_WORLD, &size);

	if (size != 2) {
		if ( ! rank) printf("mpibench: must have two processes\n");
		MPI_Finalize();
		exit(0);
	}

	if (argc > 1) {
		/*
		 * strtol() instead of atoi(): atoi() has undefined
		 * behaviour when the value does not fit in an int.
		 * Like atoi(), trailing non-digits are ignored.
		 */
		req = strtol(argv[1], NULL, 10);
		if (req > INT_MAX) {
			/* MPI message counts are ints */
			if ( ! rank) {
				printf("mpibench: nbytes too large (max %d)\n",
					INT_MAX);
			}
			MPI_Finalize();
			exit(1);
		}
		if (req > 0) nbytes = (int) req;
	}

	/*
	 * Over-allocate by ALIGN - 1 bytes so the start of the message
	 * buffer can be rounded up to an ALIGN boundary.  The size is
	 * computed in size_t so that it cannot overflow an int.
	 */
	raw = malloc((size_t) nbytes + (ALIGN - 1));
	if (raw == NULL) {
		MPI_Abort(MPI_COMM_WORLD, MPI_ERR_BUFFER);
		exit(1);
	}

	/* uintptr_t round-trips a pointer on every data model */
	buf = (char *) (((uintptr_t) raw + (ALIGN - 1))
					& ~(uintptr_t) (ALIGN - 1));
	memset(buf, 0, (size_t) nbytes);

	if (rank == 0) {
		printf("ping-pong %d bytes ...\n", nbytes);

		for (i = 0; i < NLOOPS; i++) vtmp[i] = 0;

		/*
		 * Estimate the cost of the timer itself as the smallest
		 * difference seen between two back-to-back MPI_Wtime() calls.
		 */
		ovrhd = DBL_MAX;

		for (i = 0; i < 1000; i++) {
			start = MPI_Wtime();
			stop = MPI_Wtime();
			stop -= start;
			if (stop < ovrhd) ovrhd = stop;
		}

		/* untimed warm-up exchanges */
		for (i = 0; i < 3; i++) {
			MPI_Send(buf, nbytes, MPI_CHAR, 1, 1, MPI_COMM_WORLD);
			MPI_Recv(buf, nbytes, MPI_CHAR,
					1, 1, MPI_COMM_WORLD, &status);
		}

		for (i = 0; i < NLOOPS; i++) {
#ifdef CHECK
			/* fill with a pattern that changes every iteration */
			for (j = 0; j < nbytes; j++) {
				buf[j] = (char) (j + i);
			}
#endif
			start = MPI_Wtime();

			MPI_Send(buf, nbytes, MPI_CHAR,
					1, 1000 + i, MPI_COMM_WORLD);
#ifdef CHECK
			/* wipe the buffer so a stale pattern cannot pass */
			memset(buf, 0, (size_t) nbytes);
#endif
			MPI_Recv(buf, nbytes, MPI_CHAR,
					1, 2000 + i, MPI_COMM_WORLD, &status);

			stop = MPI_Wtime();
#ifdef CHECK
			for (j = 0; j < nbytes; j++) {
				if (buf[j] != (char) (j + i)) {
					printf("error: buf[%d] = %d, not %d\n",
						j, buf[j], (char) (j + i));
					break;
				}
			}
#endif

			/* half the round trip, minus the timer overhead */
			vtmp[i] = (stop - start - ovrhd) / 2;
		}

		qsort(vtmp, NLOOPS, sizeof(double), compare);

		min = vtmp[0];
		med = vtmp[NLOOPS / 2];
		max = vtmp[NLOOPS - 1];

		printf("%d bytes: %.2f %.2f %.2f usec/msg\n",
			nbytes, min * 1000000, med * 1000000, max * 1000000);
		if (nbytes > 0) {
			printf("%d bytes: %.2f %.2f %.2f MB/sec\n",
				nbytes, nbytes / 1000000.0 / min,
				nbytes / 1000000.0 / med,
				nbytes / 1000000.0 / max);
		}
	}
	else {
		/* echo side: mirror the warm-up and the timed exchanges */
		for (i = 0; i < 3; i++) {
			MPI_Recv(buf, nbytes, MPI_CHAR,
					0, 1, MPI_COMM_WORLD, &status);
			MPI_Send(buf, nbytes, MPI_CHAR, 0, 1, MPI_COMM_WORLD);
		}

		for (i = 0; i < NLOOPS; i++) {
			MPI_Recv(buf, nbytes, MPI_CHAR,
					0, 1000 + i, MPI_COMM_WORLD, &status);
			MPI_Send(buf, nbytes, MPI_CHAR,
					0, 2000 + i, MPI_COMM_WORLD);
		}
	}

	/* free the pointer malloc() returned, not the aligned one */
	free(raw);

	MPI_Finalize();
	return 0;
}

/*
 *	compare
 *
 *	Function:	- compare two doubles (qsort() callback)
 *	Accepts:	- ptr to two doubles, passed as const void *
 *			  because that is the type qsort() calls through
 *	Returns:	- -1/0/1
 */
static int
compare(const void *p1, const void *p2)
{
	const double	a = *(const double *) p1;
	const double	b = *(const double *) p2;

	return( (a > b) ? 1 : ( (a < b) ? -1 : 0) );
}
