summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPatrick McDermott <patrick.mcdermott@libiquity.com>2019-07-29 16:11:35 (EDT)
committer Patrick McDermott <patrick.mcdermott@libiquity.com>2019-07-29 16:11:35 (EDT)
commitdc6f59b5185ade4e36d95a46380ca40a67161a55 (patch)
treebe1b6c1b99140b59af2a955b560dd239f951806d
parent22a7c2d07a52864cdba979fa442311c8f43d1ff6 (diff)
s_client: Poll, read, and write FDs
-rw-r--r--src/s_client.c88
1 files changed, 88 insertions, 0 deletions
diff --git a/src/s_client.c b/src/s_client.c
index 99fba35..863dd12 100644
--- a/src/s_client.c
+++ b/src/s_client.c
@@ -19,12 +19,15 @@
* along with wolfssl-util. If not, see <http://www.gnu.org/licenses/>.
*/
+#include <errno.h>
#include <netdb.h>
+#include <poll.h>
#include <stdbool.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
+#include <unistd.h>
#include <wolfssl/ssl.h>
#include <wolfssl/wolfcrypt/settings.h>
@@ -33,6 +36,10 @@
#define CA_CERTS "/etc/ssl/certs"
+#define ARRAY_SIZE(a) (sizeof(a) / sizeof((a)[0]))
+#define MAX(a, b) (((a) > (b)) ? (a) : (b))
+#define MIN(a, b) (((a) < (b)) ? (a) : (b))
+
static _Bool
parse_host_port(char *hostport, char **host, char **port)
{
@@ -92,6 +99,83 @@ connect_socket(const char *host, const char *port)
return sfd;
}
+static _Bool
+write_all(int fd, const void *buf, size_t count)
+{
+ ssize_t ret;
+
+ while (count > 0) {
+ while ((ret = write(fd, buf, count)) < 0 && errno == EINTR) {
+ continue;
+ }
+ if (ret < 0) {
+ return false;
+ }
+ buf = ((const char *) buf) + ret;
+ count -= ret;
+ }
+
+ return true;
+}
+
+static _Bool
+poll_fds(int sfd, WOLFSSL *ssl)
+{
+ struct pollfd fds[2] = {
+ { .fd = -1, .events = POLLIN|POLLERR, .revents = 0 },
+ { .fd = -1, .events = POLLIN|POLLERR, .revents = 0 },
+ };
+ char buf[MAX(8192, WOLFSSL_MAX_ERROR_SZ)];
+ ssize_t len;
+ int ret;
+
+ fds[0].fd = STDIN_FILENO;
+ fds[1].fd = sfd;
+
+ for (;;) {
+ while (poll(fds, ARRAY_SIZE(fds), -1) < 0 && (errno == EINTR ||
+ errno == EAGAIN)) {
+ continue;
+ }
+ if (fds[0].revents > 0) { /* stdin */
+ len = read(STDIN_FILENO, buf, sizeof(buf));
+ if (len < 0) {
+ fputs("Input read error\n", stderr);
+ return false;
+ } else if (len == 0) {
+ fds[0].fd = -1; /* Stop polling. */
+ } else if ((ret = wolfSSL_write(ssl, buf, len)) <= 0) {
+ wolfSSL_ERR_error_string(wolfSSL_get_error(ssl,
+ ret), buf);
+ fprintf(stderr, "Socket write error: %s\n",
+ buf);
+ return false;
+ }
+ }
+ if (fds[1].revents > 0) { /* socket */
+ ret = wolfSSL_read(ssl, buf, MIN(sizeof(buf), 1024));
+ if (ret < 0) {
+ wolfSSL_ERR_error_string(wolfSSL_get_error(ssl,
+ ret), buf);
+ fprintf(stderr, "Socket read error: %s\n",
+ buf);
+ return false;
+ } else if (ret == 0) {
+ fds[1].fd = -1; /* Stop polling. */
+ close(STDOUT_FILENO); /* Signal socket EOF. */
+ } else if (write_all(STDOUT_FILENO, buf, (size_t) ret)
+ == false) {
+ return false;
+ }
+ }
+ if (fds[0].fd == fds[1].fd) { /* Both -1 (no longer polled) */
+ return true;
+ }
+ }
+
+ /* Unreached */
+}
+
int
s_client(int argc, char **argv)
{
@@ -177,6 +261,10 @@ s_client(int argc, char **argv)
goto ssl_free;
}
+ if (poll_fds(sfd, ssl) == false) {
+ ret = EXIT_FAILURE;
+ }
+
close(sfd);
ssl_free: