From c9d537cd341c96f21d749006002291e30a383463 Mon Sep 17 00:00:00 2001
From: Justin Pryzby <pryzbyj@telsasoft.com>
Date: Mon, 21 Dec 2020 00:11:43 -0600
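Subject: [PATCH 5/7] pg_dump: make struct cfp a union tagged by CompressionAlgorithm

Store the per-algorithm file handles in a union and dispatch on fp->alg
instead of testing which handle pointer is non-NULL.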
---
 src/bin/pg_dump/compress_io.c | 219 ++++++++++++++++------------------
 1 file changed, 106 insertions(+), 113 deletions(-)

diff --git a/src/bin/pg_dump/compress_io.c b/src/bin/pg_dump/compress_io.c
index 10db22ff88..ae1ae7b0d0 100644
--- a/src/bin/pg_dump/compress_io.c
+++ b/src/bin/pg_dump/compress_io.c
@@ -614,22 +614,27 @@ WriteDataToArchiveNone(ArchiveHandle *AH, CompressorState *cs,
  */
 struct cfp
 {
-	FILE	   *uncompressedfp;
+	CompressionAlgorithm alg;	/* selects which member of u is valid */
+
+	union {
+		FILE	   *fp;			/* COMPR_ALG_NONE: plain stdio file */
+
 #ifdef HAVE_LIBZ
-	gzFile		compressedfp;
+		gzFile		gzfp;		/* COMPR_ALG_LIBZ */
 #endif
 
-#ifdef HAVE_LIBZSTD // XXX: this should be a union with a CompressionAlgorithm alg?
-	/* This is a normal file to which we read/write compressed data */
-	struct {
-		FILE			*fp;
-		// XXX: use one separate ZSTD_CStream per thread: disable on windows ?
-		ZSTD_CStream	*cstream;
-		ZSTD_DStream	*dstream;
-		ZSTD_outBuffer	output;
-		ZSTD_inBuffer	input;
-	} zstd;
+#ifdef HAVE_LIBZSTD
+		struct {				/* COMPR_ALG_ZSTD */
+			/* This is a normal file to which we read/write compressed data */
+			FILE			*fp;
+			/* XXX: use one separate ZSTD_CStream per thread: disable on windows ? */
+			ZSTD_CStream	*cstream;	/* created for 'w' mode, NULL for 'r' */
+			ZSTD_DStream	*dstream;	/* created for 'r' mode, NULL for 'w' */
+			ZSTD_outBuffer	output;
+			ZSTD_inBuffer	input;
+		} zstd;
 #endif
+	} u;
 
 };
 
@@ -730,13 +735,7 @@ cfopen(const char *path, const char *mode, Compress *compression)
 {
-	cfp		   *fp = pg_malloc(sizeof(cfp));
+	cfp		   *fp = pg_malloc0(sizeof(cfp));	/* zeroed: unused handles/zstd buffers start NULL */
 
-	fp->uncompressedfp = NULL;
-#ifdef HAVE_LIBZ
-	fp->compressedfp = NULL;
-#endif
-#ifdef HAVE_LIBZSTD
-	fp->zstd.fp = NULL;
-#endif
+	fp->alg = compression->alg;	/* selects the active member of fp->u */
 
 	switch (compression->alg)
 	{
@@ -749,15 +748,15 @@ cfopen(const char *path, const char *mode, Compress *compression)
 
 			snprintf(mode_compression, sizeof(mode_compression), "%s%d",
 					 mode, compression->level);
-			fp->compressedfp = gzopen(path, mode_compression);
+			fp->u.gzfp = gzopen(path, mode_compression);
 		}
 		else
 		{
 			/* don't specify a level, just use the zlib default */
-			fp->compressedfp = gzopen(path, mode);
+			fp->u.gzfp = gzopen(path, mode);
 		}
 
-		if (fp->compressedfp == NULL)
+		if (fp->u.gzfp == NULL)
 		{
 			free_keep_errno(fp);
 			fp = NULL;
@@ -769,29 +768,29 @@ cfopen(const char *path, const char *mode, Compress *compression)
 
 #ifdef HAVE_LIBZSTD
 	case COMPR_ALG_ZSTD:
-		fp->zstd.fp = fopen(path, mode);
+		fp->u.zstd.fp = fopen(path, mode);
 		// XXX: save the compression params
-		if (fp->zstd.fp == NULL)
+		if (fp->u.zstd.fp == NULL)
 		{
 			free_keep_errno(fp);
 			fp = NULL;
 		}
 		else if (strchr(mode, 'w'))
 		{
-			fp->zstd.dstream = NULL;
-			fp->zstd.output.size = ZSTD_CStreamOutSize(); // XXX
-			fp->zstd.output.dst = pg_malloc0(fp->zstd.output.size);
-			fp->zstd.cstream = ZSTD_createCStream();
-			if (fp->zstd.cstream == NULL)
+			fp->u.zstd.dstream = NULL;
+			fp->u.zstd.output.size = ZSTD_CStreamOutSize(); // XXX
+			fp->u.zstd.output.dst = pg_malloc0(fp->u.zstd.output.size);
+			fp->u.zstd.cstream = ZSTD_createCStream();
+			if (fp->u.zstd.cstream == NULL)
 				fatal("could not initialize compression library");
 		}
 		else if (strchr(mode, 'r'))
 		{
-			fp->zstd.cstream = NULL;
-			fp->zstd.input.size = ZSTD_DStreamOutSize(); // XXX
-			fp->zstd.input.src = pg_malloc0(fp->zstd.input.size);
-			fp->zstd.dstream = ZSTD_createDStream();
-			if (fp->zstd.dstream == NULL)
+			fp->u.zstd.cstream = NULL;
+			fp->u.zstd.input.size = ZSTD_DStreamInSize();	/* cfread() freads this many bytes into input.src */
+			fp->u.zstd.input.src = pg_malloc0(fp->u.zstd.input.size);
+			fp->u.zstd.dstream = ZSTD_createDStream();
+			if (fp->u.zstd.dstream == NULL)
 				fatal("could not initialize compression library");
 		} // XXX else: bad mode
 		return fp;
@@ -799,8 +798,8 @@ cfopen(const char *path, const char *mode, Compress *compression)
 #endif
 
 	case COMPR_ALG_NONE:
-		fp->uncompressedfp = fopen(path, mode);
-		if (fp->uncompressedfp == NULL)
+		fp->u.fp = fopen(path, mode);
+		if (fp->u.fp == NULL)
 		{
 			free_keep_errno(fp);
 			fp = NULL;
@@ -824,26 +823,26 @@ cfread(void *ptr, int size, cfp *fp)
 		return 0;
 
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
+	if (fp->alg == COMPR_ALG_LIBZ)
 	{
-		ret = gzread(fp->compressedfp, ptr, size);
-		if (ret != size && !gzeof(fp->compressedfp))
+		ret = gzread(fp->u.gzfp, ptr, size);
+		if (ret != size && !gzeof(fp->u.gzfp))
 		{
 			int			errnum;
-			const char *errmsg = gzerror(fp->compressedfp, &errnum);
+			const char *errmsg = gzerror(fp->u.gzfp, &errnum);
 
 			fatal("could not read from input file: %s",
 				  errnum == Z_ERRNO ? strerror(errno) : errmsg);
 		}
+		return ret;
 	}
-	else
 #endif
 
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
-		ZSTD_outBuffer	*output = &fp->zstd.output;
-		ZSTD_inBuffer	*input = &fp->zstd.input;
+		ZSTD_outBuffer	*output = &fp->u.zstd.output;
+		ZSTD_inBuffer	*input = &fp->u.zstd.input;
 		size_t			input_size = ZSTD_DStreamInSize();
 		size_t			res, cnt;
 
@@ -852,7 +851,7 @@ cfread(void *ptr, int size, cfp *fp)
 		output->pos = 0;
 
 		/* read compressed data */
-		while ((cnt = fread(unconstify(void *, input->src), 1, input_size, fp->zstd.fp)))
+		while ((cnt = fread(unconstify(void *, input->src), 1, input_size, fp->u.zstd.fp)))
 		{
 			input->size = cnt;
 			input->pos = 0;
@@ -860,7 +859,7 @@ cfread(void *ptr, int size, cfp *fp)
 			for ( ; input->pos < input->size; )
 			{
 				/* decompress */
-				res = ZSTD_decompressStream(fp->zstd.dstream, output, input);
+				res = ZSTD_decompressStream(fp->u.zstd.dstream, output, input);
 				if (res == 0 || output->pos == output->size)
 					break;
 				if (ZSTD_isError(res))
@@ -871,16 +870,13 @@ cfread(void *ptr, int size, cfp *fp)
 				break; /* We read all the data that fits */
 		}
 
-		ret = output->pos;
+		return output->pos;
 	}
-	else
 #endif
 
-	{
-		ret = fread(ptr, 1, size, fp->uncompressedfp);
-		if (ret != size && !feof(fp->uncompressedfp))
-			READ_ERROR_EXIT(fp->uncompressedfp);
-	}
+	ret = fread(ptr, 1, size, fp->u.fp);
+	if (ret != size && !feof(fp->u.fp))
+		READ_ERROR_EXIT(fp->u.fp);
 	return ret;
 }
 
@@ -888,16 +884,16 @@ int
 cfwrite(const void *ptr, int size, cfp *fp)
 {
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
-		return gzwrite(fp->compressedfp, ptr, size);
-	else
+	if (fp->alg == COMPR_ALG_LIBZ)
+		return gzwrite(fp->u.gzfp, ptr, size);
 #endif
+
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
 		size_t      res, cnt;
-		ZSTD_outBuffer	*output = &fp->zstd.output;
-		ZSTD_inBuffer	*input = &fp->zstd.input;
+		ZSTD_outBuffer	*output = &fp->u.zstd.output;
+		ZSTD_inBuffer	*input = &fp->u.zstd.input;
 
 		input->src = ptr;
 		input->size = size;
@@ -907,21 +903,20 @@ cfwrite(const void *ptr, int size, cfp *fp)
 		while (input->pos != input->size)
 		{
 			output->pos = 0;
-			res = ZSTD_compressStream2(fp->zstd.cstream, output, input, ZSTD_e_continue);
+			res = ZSTD_compressStream2(fp->u.zstd.cstream, output, input, ZSTD_e_continue);
 			if (ZSTD_isError(res))
 				fatal("could not compress data: %s", ZSTD_getErrorName(res));
 
-			cnt = fwrite(output->dst, 1, output->pos, fp->zstd.fp);
+			cnt = fwrite(output->dst, 1, output->pos, fp->u.zstd.fp);
 			if (cnt != output->pos)
 				fatal("could not write data: %s", strerror(errno));
 		}
 
 		return size;
 	}
-	else
 #endif
 
-		return fwrite(ptr, 1, size, fp->uncompressedfp);
+	return fwrite(ptr, 1, size, fp->u.fp);	/* no compression: plain stdio FILE */
 }
 
 int
@@ -930,39 +925,38 @@ cfgetc(cfp *fp)
 	int			ret;
 
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
+	if (fp->alg == COMPR_ALG_LIBZ)
 	{
-		ret = gzgetc(fp->compressedfp);
+		ret = gzgetc(fp->u.gzfp);
 		if (ret == EOF)
 		{
-			if (!gzeof(fp->compressedfp))
+			if (!gzeof(fp->u.gzfp))
 				fatal("could not read from input file: %s", strerror(errno));
 			else
 				fatal("could not read from input file: end of file");
 		}
+		return ret;
 	}
-	else
 #endif
+
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
+		unsigned char c;	/* one byte; reading into the int ret would leave its other bytes unset */
-		if (cfread(&ret, 1, fp) != 1)
+		if (cfread(&c, 1, fp) != 1)
 		{
-			if (feof(fp->zstd.fp))
+			if (feof(fp->u.zstd.fp))
 				fatal("could not read from input file: end of file");
 			else
 				fatal("could not read from input file: %s", strerror(errno));
 		}
-fprintf(stderr, "cfgetc %d\n", ret);
+		return c;	/* 0..255, like fgetc() */
 	}
 #endif
 
-	{
-		ret = fgetc(fp->uncompressedfp);
-		if (ret == EOF)
-			READ_ERROR_EXIT(fp->uncompressedfp);
-	}
-
+	ret = fgetc(fp->u.fp);
+	if (ret == EOF)
+		READ_ERROR_EXIT(fp->u.fp);
 	return ret;
 }
 
@@ -970,12 +964,12 @@ char *
 cfgets(cfp *fp, char *buf, int len)
 {
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
-		return gzgets(fp->compressedfp, buf, len);
-	else
+	if (fp->alg == COMPR_ALG_LIBZ)
+		return gzgets(fp->u.gzfp, buf, len);
 #endif
+
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
 		int res;
 		res = cfread(buf, len, fp);
@@ -984,9 +978,9 @@ cfgets(cfp *fp, char *buf, int len)
 			*strchr(buf, '\n') = '\0';
 		return res > 0 ? buf : 0;
 	}
-	else
 #endif
-		return fgets(buf, len, fp->uncompressedfp);
+
+	return fgets(buf, len, fp->u.fp);	/* no compression: plain stdio FILE */
 }
 
 int
@@ -1000,56 +994,55 @@ cfclose(cfp *fp)
 		return EOF;
 	}
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
+	if (fp->alg == COMPR_ALG_LIBZ)
 	{
-		result = gzclose(fp->compressedfp);
-		fp->compressedfp = NULL;
+		result = gzclose(fp->u.gzfp);
+		free_keep_errno(fp);	/* early return: release the cfp here too */
+		return result;
 	}
-	else
 #endif
+
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
-		ZSTD_outBuffer	*output = &fp->zstd.output;
-		ZSTD_inBuffer	*input = &fp->zstd.input;
+		ZSTD_outBuffer	*output = &fp->u.zstd.output;
+		ZSTD_inBuffer	*input = &fp->u.zstd.input;
 		size_t res, cnt;
 
-		if (fp->zstd.cstream)
+		if (fp->u.zstd.cstream)
 		{
 			for (;;)
 			{
 				output->pos = 0;
-				res = ZSTD_compressStream2(fp->zstd.cstream, output, input, ZSTD_e_end);
+				res = ZSTD_compressStream2(fp->u.zstd.cstream, output, input, ZSTD_e_end);
 				if (ZSTD_isError(res))
 					fatal("could not compress data: %s", ZSTD_getErrorName(res));
-				cnt = fwrite(output->dst, 1, output->pos, fp->zstd.fp);
+				cnt = fwrite(output->dst, 1, output->pos, fp->u.zstd.fp);
 				if (cnt != output->pos)
 					fatal("could not write data: %s", strerror(errno));
 				if (res == 0)
 					break;
 			}
 
-			ZSTD_freeCStream(fp->zstd.cstream);
-			pg_free(fp->zstd.output.dst);
+			ZSTD_freeCStream(fp->u.zstd.cstream);
+			pg_free(fp->u.zstd.output.dst);
 		}
 
-		if (fp->zstd.dstream)
+		if (fp->u.zstd.dstream)
 		{
-			ZSTD_freeDStream(fp->zstd.dstream);
-			pg_free(unconstify(void *, fp->zstd.input.src));
+			ZSTD_freeDStream(fp->u.zstd.dstream);
+			pg_free(unconstify(void *, fp->u.zstd.input.src));
 		}
 
-		result = fclose(fp->zstd.fp);
-		fp->zstd.fp = NULL;
+		result = fclose(fp->u.zstd.fp);
+		free_keep_errno(fp);	/* early return: release the cfp here too */
+		return result;
 	}
-	else
 #endif
-	{
-		result = fclose(fp->uncompressedfp);
-		fp->uncompressedfp = NULL;
-	}
-	free_keep_errno(fp);
 
+	result = fclose(fp->u.fp);
+	fp->u.fp = NULL;
+	free_keep_errno(fp);
 	return result;
 }
 
@@ -1057,26 +1050,26 @@ int
 cfeof(cfp *fp)
 {
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
-		return gzeof(fp->compressedfp);
-	else
+	if (fp->alg == COMPR_ALG_LIBZ)
+		return gzeof(fp->u.gzfp);
 #endif
+
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
-		return feof(fp->zstd.fp);
-	else
+	if (fp->alg == COMPR_ALG_ZSTD)
+		return feof(fp->u.zstd.fp);
 #endif
-		return feof(fp->uncompressedfp);
+
+	return feof(fp->u.fp);
 }
 
 const char *
 get_cfp_error(cfp *fp)
 {
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
+	if (fp->alg == COMPR_ALG_LIBZ)
 	{
 		int			errnum;
-		const char *errmsg = gzerror(fp->compressedfp, &errnum);
+		const char *errmsg = gzerror(fp->u.gzfp, &errnum);
 
 		if (errnum != Z_ERRNO)
 			return errmsg;
-- 
2.17.0

