From 44b3ed951859072b8d814d0439565187bf960b7b Mon Sep 17 00:00:00 2001
From: Justin Pryzby <pryzbyj@telsasoft.com>
Date: Mon, 21 Dec 2020 00:11:43 -0600
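Subject: [PATCH 08/20] Make struct cfp a union tagged by CompressionAlgorithm

Store the compression algorithm in a new "alg" member of struct cfp and
move the per-algorithm file handles (plain FILE, gzFile, zstd state) into
a union "u".  The cf* functions now dispatch on fp->alg instead of
testing which handle pointer is non-NULL.
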
---
 src/bin/pg_dump/compress_io.c | 200 ++++++++++++++++++----------------
 src/bin/pg_dump/pg_dump.c     |   2 +-
 2 files changed, 106 insertions(+), 96 deletions(-)

diff --git a/src/bin/pg_dump/compress_io.c b/src/bin/pg_dump/compress_io.c
index fa94148cdf..e07436bc21 100644
--- a/src/bin/pg_dump/compress_io.c
+++ b/src/bin/pg_dump/compress_io.c
@@ -651,23 +651,27 @@ WriteDataToArchiveNone(ArchiveHandle *AH, CompressorState *cs,
  */
 struct cfp
 {
-	FILE	   *uncompressedfp;
+	CompressionAlgorithm alg;	/* tag: selects which member of "u" is in use */
+
+	union {
+		FILE	   *fp;			/* used for COMPR_ALG_NONE: plain stdio stream */
+
 #ifdef HAVE_LIBZ
-	gzFile		compressedfp;
+		gzFile		gzfp;		/* used for COMPR_ALG_LIBZ: zlib file handle */
 #endif
 
-#ifdef HAVE_LIBZSTD // XXX: this should be a union with a CompressionAlgorithm alg?
-	/* This is a normal file to which we read/write compressed data */
-	struct {
-		FILE			*fp;
-		// XXX: use one separate ZSTD_CStream per thread: disable on windows ?
-		ZSTD_CStream	*cstream;
-		ZSTD_DStream	*dstream;
-		ZSTD_outBuffer	output;
-		ZSTD_inBuffer	input;
-	} zstd;
+#ifdef HAVE_LIBZSTD
+		struct {				/* used for COMPR_ALG_ZSTD */
+			/* This is a normal file to which we read/write compressed data */
+			FILE			*fp;
+			// XXX: use one separate ZSTD_CStream per thread: disable on windows ?
+			ZSTD_CStream	*cstream;	/* compressor; created only when opened for write/append/update */
+			ZSTD_DStream	*dstream;	/* decompressor; created only when opened for read */
+			ZSTD_outBuffer	output;		/* when compressing, dst is a buffer allocated at open and freed at close */
+			ZSTD_inBuffer	input;		/* when decompressing, src is a buffer allocated at open and freed at close */
+		} zstd;
 #endif
-
+	} u;
 };
 
 static int	hasSuffix(const char *filename);
@@ -754,6 +758,8 @@ cfopen(const char *path, const char *mode, Compress *compression)
 {
 	cfp		   *fp = pg_malloc0(sizeof(cfp));
 
+	fp->alg = compression->alg;
+
 	switch (compression->alg)
 	{
 #ifdef HAVE_LIBZ
@@ -765,15 +771,15 @@ cfopen(const char *path, const char *mode, Compress *compression)
 
 			snprintf(mode_compression, sizeof(mode_compression), "%s%d",
 					 mode, compression->level);
-			fp->compressedfp = gzopen(path, mode_compression);
+			fp->u.gzfp = gzopen(path, mode_compression);
 		}
 		else
 		{
 			/* don't specify a level, just use the zlib default */
-			fp->compressedfp = gzopen(path, mode);
+			fp->u.gzfp = gzopen(path, mode);
 		}
 
-		if (fp->compressedfp == NULL)
+		if (fp->u.gzfp == NULL)
 		{
 			free_keep_errno(fp);
 			fp = NULL;
@@ -783,8 +789,8 @@ cfopen(const char *path, const char *mode, Compress *compression)
 
 #ifdef HAVE_LIBZSTD
 	case COMPR_ALG_ZSTD:
-		fp->zstd.fp = fopen(path, mode);
-		if (fp->zstd.fp == NULL)
+		fp->u.zstd.fp = fopen(path, mode);
+		if (fp->u.zstd.fp == NULL)
 		{
 			free_keep_errno(fp);
 			fp = NULL;
@@ -792,23 +798,23 @@ cfopen(const char *path, const char *mode, Compress *compression)
 		else if (mode[0] == 'w' || mode[0] == 'a' ||
 			strchr(mode, '+') != NULL)
 		{
-			fp->zstd.output.size = ZSTD_CStreamOutSize();
-			fp->zstd.output.dst = pg_malloc0(fp->zstd.output.size);
-			fp->zstd.cstream = ZstdCStreamParams(compression);
+			fp->u.zstd.output.size = ZSTD_CStreamOutSize();
+			fp->u.zstd.output.dst = pg_malloc0(fp->u.zstd.output.size);
+			fp->u.zstd.cstream = ZstdCStreamParams(compression);
 		}
 		else if (strchr(mode, 'r'))
 		{
-			fp->zstd.input.src = pg_malloc0(ZSTD_DStreamInSize());
-			fp->zstd.dstream = ZSTD_createDStream();
-			if (fp->zstd.dstream == NULL)
+			fp->u.zstd.input.src = pg_malloc0(ZSTD_DStreamInSize());
+			fp->u.zstd.dstream = ZSTD_createDStream();
+			if (fp->u.zstd.dstream == NULL)
 				fatal("could not initialize compression library");
 		} // XXX else: bad mode
 		return fp;
 #endif
 
 	case COMPR_ALG_NONE:
-		fp->uncompressedfp = fopen(path, mode);
-		if (fp->uncompressedfp == NULL)
+		fp->u.fp = fopen(path, mode);
+		if (fp->u.fp == NULL)
 		{
 			free_keep_errno(fp);
 			fp = NULL;
@@ -830,6 +836,8 @@ cfdopen(int fd, const char *mode, Compress *compression)
 {
 	cfp		   *fp = pg_malloc0(sizeof(cfp));
 
+	fp->alg = compression->alg;
+
 	switch (compression->alg)
 	{
 #ifdef HAVE_LIBZ
@@ -841,15 +849,15 @@ cfdopen(int fd, const char *mode, Compress *compression)
 
 			snprintf(mode_compression, sizeof(mode_compression), "%s%d",
 					 mode, compression->level);
-			fp->compressedfp = gzdopen(fd, mode_compression);
+			fp->u.gzfp = gzdopen(fd, mode_compression);
 		}
 		else
 		{
 			/* don't specify a level, just use the zlib default */
-			fp->compressedfp = gzdopen(fd, mode);
+			fp->u.gzfp = gzdopen(fd, mode);
 		}
 
-		if (fp->compressedfp == NULL)
+		if (fp->u.gzfp == NULL)
 		{
 			free_keep_errno(fp);
 			fp = NULL;
@@ -859,8 +867,8 @@ cfdopen(int fd, const char *mode, Compress *compression)
 
 #ifdef HAVE_LIBZSTD
 	case COMPR_ALG_ZSTD:
-		fp->zstd.fp = fdopen(fd, mode);
-		if (fp->zstd.fp == NULL)
+		fp->u.zstd.fp = fdopen(fd, mode);
+		if (fp->u.zstd.fp == NULL)
 		{
 			free_keep_errno(fp);
 			fp = NULL;
@@ -868,23 +876,23 @@ cfdopen(int fd, const char *mode, Compress *compression)
 		else if (mode[0] == 'w' || mode[0] == 'a' ||
 			strchr(mode, '+') != NULL)
 		{
-			fp->zstd.output.size = ZSTD_CStreamOutSize();
-			fp->zstd.output.dst = pg_malloc0(fp->zstd.output.size);
-			fp->zstd.cstream = ZstdCStreamParams(compression);
+			fp->u.zstd.output.size = ZSTD_CStreamOutSize();
+			fp->u.zstd.output.dst = pg_malloc0(fp->u.zstd.output.size);
+			fp->u.zstd.cstream = ZstdCStreamParams(compression);
 		}
 		else if (strchr(mode, 'r'))
 		{
-			fp->zstd.input.src = pg_malloc0(ZSTD_DStreamInSize());
-			fp->zstd.dstream = ZSTD_createDStream();
-			if (fp->zstd.dstream == NULL)
+			fp->u.zstd.input.src = pg_malloc0(ZSTD_DStreamInSize());
+			fp->u.zstd.dstream = ZSTD_createDStream();
+			if (fp->u.zstd.dstream == NULL)
 				fatal("could not initialize compression library");
 		} // XXX else: bad mode
 		return fp;
 #endif
 
 	case COMPR_ALG_NONE:
-		fp->uncompressedfp = fdopen(fd, mode);
-		if (fp->uncompressedfp == NULL)
+		fp->u.fp = fdopen(fd, mode);
+		if (fp->u.fp == NULL)
 		{
 			free_keep_errno(fp);
 			fp = NULL;
@@ -908,13 +916,13 @@ cfread(void *ptr, int size, cfp *fp)
 		return 0;
 
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
+	if (fp->alg == COMPR_ALG_LIBZ)
 	{
-		ret = gzread(fp->compressedfp, ptr, size);
-		if (ret != size && !gzeof(fp->compressedfp))
+		ret = gzread(fp->u.gzfp, ptr, size);
+		if (ret != size && !gzeof(fp->u.gzfp))
 		{
 			int			errnum;
-			const char *errmsg = gzerror(fp->compressedfp, &errnum);
+			const char *errmsg = gzerror(fp->u.gzfp, &errnum);
 
 			fatal("could not read from input file: %s",
 				  errnum == Z_ERRNO ? strerror(errno) : errmsg);
@@ -924,10 +932,10 @@ cfread(void *ptr, int size, cfp *fp)
 #endif
 
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
-		ZSTD_outBuffer	*output = &fp->zstd.output;
-		ZSTD_inBuffer	*input = &fp->zstd.input;
+		ZSTD_outBuffer	*output = &fp->u.zstd.output;
+		ZSTD_inBuffer	*input = &fp->u.zstd.input;
 		size_t			input_size = ZSTD_DStreamInSize();
 		/* input_size is the allocated size */
 		size_t			res, cnt;
@@ -953,7 +961,7 @@ cfread(void *ptr, int size, cfp *fp)
 			/* read compressed data if we must produce more input */
 			if (input->pos == input->size)
 			{
-				cnt = fread(unconstify(void *, input->src), 1, input_size, fp->zstd.fp);
+				cnt = fread(unconstify(void *, input->src), 1, input_size, fp->u.zstd.fp);
 				input->size = cnt;
 
 				/* If we have no input to consume, we're done */
@@ -968,7 +976,7 @@ cfread(void *ptr, int size, cfp *fp)
 			for ( ; input->pos < input->size; )
 			{
 				/* decompress */
-				res = ZSTD_decompressStream(fp->zstd.dstream, output, input);
+				res = ZSTD_decompressStream(fp->u.zstd.dstream, output, input);
 				if (res == 0)
 					break; /* End of frame */
 				if (output->pos == output->size)
@@ -985,9 +993,9 @@ cfread(void *ptr, int size, cfp *fp)
 	}
 #endif
 
-	ret = fread(ptr, 1, size, fp->uncompressedfp);
-	if (ret != size && !feof(fp->uncompressedfp))
-		READ_ERROR_EXIT(fp->uncompressedfp);
+	ret = fread(ptr, 1, size, fp->u.fp);
+	if (ret != size && !feof(fp->u.fp))
+		READ_ERROR_EXIT(fp->u.fp);
 	return ret;
 }
 
@@ -995,16 +1003,16 @@ int
 cfwrite(const void *ptr, int size, cfp *fp)
 {
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
-		return gzwrite(fp->compressedfp, ptr, size);
+	if (fp->alg == COMPR_ALG_LIBZ)
+		return gzwrite(fp->u.gzfp, ptr, size);
 #endif
 
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
 		size_t      res, cnt;
-		ZSTD_outBuffer	*output = &fp->zstd.output;
-		ZSTD_inBuffer	*input = &fp->zstd.input;
+		ZSTD_outBuffer	*output = &fp->u.zstd.output;
+		ZSTD_inBuffer	*input = &fp->u.zstd.input;
 
 		input->src = ptr;
 		input->size = size;
@@ -1014,11 +1022,11 @@ cfwrite(const void *ptr, int size, cfp *fp)
 		while (input->pos != input->size)
 		{
 			output->pos = 0;
-			res = ZSTD_compressStream2(fp->zstd.cstream, output, input, ZSTD_e_continue);
+			res = ZSTD_compressStream2(fp->u.zstd.cstream, output, input, ZSTD_e_continue);
 			if (ZSTD_isError(res))
 				fatal("could not compress data: %s", ZSTD_getErrorName(res));
 
-			cnt = fwrite(output->dst, 1, output->pos, fp->zstd.fp);
+			cnt = fwrite(output->dst, 1, output->pos, fp->u.zstd.fp);
 			if (cnt != output->pos)
 				fatal("could not write data: %s", strerror(errno));
 		}
@@ -1027,7 +1035,7 @@ cfwrite(const void *ptr, int size, cfp *fp)
 	}
 #endif
 
-	return fwrite(ptr, 1, size, fp->uncompressedfp);
+	return fwrite(ptr, 1, size, fp->u.fp);
 }
 
 int
@@ -1036,12 +1044,12 @@ cfgetc(cfp *fp)
 	int			ret;
 
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
+	if (fp->alg == COMPR_ALG_LIBZ)
 	{
-		ret = gzgetc(fp->compressedfp);
+		ret = gzgetc(fp->u.gzfp);
 		if (ret == EOF)
 		{
-			if (!gzeof(fp->compressedfp))
+			if (!gzeof(fp->u.gzfp))
 				fatal("could not read from input file: %s", strerror(errno));
 			else
 				fatal("could not read from input file: end of file");
@@ -1051,11 +1059,11 @@ cfgetc(cfp *fp)
 #endif
 
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
 		if (cfread(&ret, 1, fp) != 1)
 		{
-			if (feof(fp->zstd.fp))
+			if (feof(fp->u.zstd.fp))
 				fatal("could not read from input file: end of file");
 			else
 				fatal("could not read from input file: %s", strerror(errno));
@@ -1064,9 +1072,9 @@ cfgetc(cfp *fp)
 	}
 #endif
 
-	ret = fgetc(fp->uncompressedfp);
+	ret = fgetc(fp->u.fp);
 	if (ret == EOF)
-		READ_ERROR_EXIT(fp->uncompressedfp);
+		READ_ERROR_EXIT(fp->u.fp);
 	return ret;
 }
 
@@ -1074,11 +1082,12 @@ char *
 cfgets(cfp *fp, char *buf, int len)
 {
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
-		return gzgets(fp->compressedfp, buf, len);
+	if (fp->alg == COMPR_ALG_LIBZ)
+		return gzgets(fp->u.gzfp, buf, len);
 #endif
+
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
 		/*
 		 * Read one byte at a time until newline or EOF.
@@ -1102,7 +1111,7 @@ cfgets(cfp *fp, char *buf, int len)
 	}
 #endif
 
-	return fgets(buf, len, fp->uncompressedfp);
+	return fgets(buf, len, fp->u.fp);
 }
 
 /* Close the given compressed or uncompressed stream; return 0 on success. */
@@ -1117,54 +1126,54 @@ cfclose(cfp *fp)
 		return EOF;
 	}
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
+	if (fp->alg == COMPR_ALG_LIBZ)
 	{
-		result = gzclose(fp->compressedfp);
-		fp->compressedfp = NULL;
+		result = gzclose(fp->u.gzfp);
+		fp->u.gzfp = NULL;
 		return result;
 	}
 #endif
 
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
+	if (fp->alg == COMPR_ALG_ZSTD)
 	{
-		ZSTD_outBuffer	*output = &fp->zstd.output;
-		ZSTD_inBuffer	*input = &fp->zstd.input;
+		ZSTD_outBuffer	*output = &fp->u.zstd.output;
+		ZSTD_inBuffer	*input = &fp->u.zstd.input;
 		size_t res, cnt;
 
-		if (fp->zstd.cstream)
+		if (fp->u.zstd.cstream)
 		{
 			for (;;)
 			{
 				output->pos = 0;
-				res = ZSTD_compressStream2(fp->zstd.cstream, output, input, ZSTD_e_end);
+				res = ZSTD_compressStream2(fp->u.zstd.cstream, output, input, ZSTD_e_end);
 				if (ZSTD_isError(res))
 					fatal("could not compress data: %s", ZSTD_getErrorName(res));
-				cnt = fwrite(output->dst, 1, output->pos, fp->zstd.fp);
+				cnt = fwrite(output->dst, 1, output->pos, fp->u.zstd.fp);
 				if (cnt != output->pos)
 					fatal("could not write data: %s", strerror(errno));
 				if (res == 0)
 					break;
 			}
 
-			ZSTD_freeCStream(fp->zstd.cstream);
-			pg_free(fp->zstd.output.dst);
+			ZSTD_freeCStream(fp->u.zstd.cstream);
+			pg_free(fp->u.zstd.output.dst);
 		}
 
-		if (fp->zstd.dstream)
+		if (fp->u.zstd.dstream)
 		{
-			ZSTD_freeDStream(fp->zstd.dstream);
-			pg_free(unconstify(void *, fp->zstd.input.src));
+			ZSTD_freeDStream(fp->u.zstd.dstream);
+			pg_free(unconstify(void *, fp->u.zstd.input.src));
 		}
 
-		result = fclose(fp->zstd.fp);
-		fp->zstd.fp = NULL;
+		result = fclose(fp->u.zstd.fp);
+		fp->u.zstd.fp = NULL;
 		return result;
 	}
 #endif
 
-	result = fclose(fp->uncompressedfp);
-	fp->uncompressedfp = NULL;
+	result = fclose(fp->u.fp);
+	fp->u.fp = NULL;
 	free_keep_errno(fp);
 	return result;
 }
@@ -1173,25 +1182,26 @@ int
 cfeof(cfp *fp)
 {
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
-		return gzeof(fp->compressedfp);
+	if (fp->alg == COMPR_ALG_LIBZ)
+		return gzeof(fp->u.gzfp);
 #endif
 
 #ifdef HAVE_LIBZSTD
-	if (fp->zstd.fp)
-		return feof(fp->zstd.fp);
+	if (fp->alg == COMPR_ALG_ZSTD)
+		return feof(fp->u.zstd.fp);
 #endif
-	return feof(fp->uncompressedfp);
+
+	return feof(fp->u.fp);
 }
 
 const char *
 get_cfp_error(cfp *fp)
 {
 #ifdef HAVE_LIBZ
-	if (fp->compressedfp)
+	if (fp->alg == COMPR_ALG_LIBZ)
 	{
 		int			errnum;
-		const char *errmsg = gzerror(fp->compressedfp, &errnum);
+		const char *errmsg = gzerror(fp->u.gzfp, &errnum);
 
 		if (errnum != Z_ERRNO)
 			return errmsg;
diff --git a/src/bin/pg_dump/pg_dump.c b/src/bin/pg_dump/pg_dump.c
index 7c2f7a9ca3..5e009e5854 100644
--- a/src/bin/pg_dump/pg_dump.c
+++ b/src/bin/pg_dump/pg_dump.c
@@ -397,7 +397,7 @@ parse_compression(const char *optarg, Compress *compress)
 			const int default_compress_level[] = {
 				0,			/* COMPR_ALG_NONE */
 				Z_DEFAULT_COMPRESSION,	/* COMPR_ALG_ZLIB */
-				0, // XXX: ZSTD_CLEVEL_DEFAULT,	/* COMPR_ALG_ZSTD */
+				0, // XXX: use ZSTD_CLEVEL_DEFAULT when built with HAVE_LIBZSTD	/* COMPR_ALG_ZSTD */
 			};
 
 			compress->level = default_compress_level[compress->alg];
-- 
2.17.0

