diff options
author | Patrick McDermott <patrick.mcdermott@libiquity.com> | 2019-07-29 16:11:35 (EDT) |
---|---|---|
committer | Patrick McDermott <patrick.mcdermott@libiquity.com> | 2019-07-29 16:11:35 (EDT) |
commit | dc6f59b5185ade4e36d95a46380ca40a67161a55 (patch) | |
tree | be1b6c1b99140b59af2a955b560dd239f951806d | |
parent | 22a7c2d07a52864cdba979fa442311c8f43d1ff6 (diff) |
s_client: Poll, read, and write FDs
-rw-r--r-- | src/s_client.c | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/src/s_client.c b/src/s_client.c index 99fba35..863dd12 100644 --- a/src/s_client.c +++ b/src/s_client.c @@ -19,12 +19,15 @@ * along with wolfssl-util. If not, see <http://www.gnu.org/licenses/>. */ +#include <errno.h> #include <netdb.h> +#include <poll.h> #include <stdbool.h> #include <stdlib.h> #include <stdio.h> #include <string.h> #include <sys/socket.h> +#include <unistd.h> #include <wolfssl/ssl.h> #include <wolfssl/wolfcrypt/settings.h> @@ -33,6 +36,10 @@ #define CA_CERTS "/etc/ssl/certs" +#define ARRAY_SIZE(a) (sizeof(a) / sizeof((a)[0])) +#define MAX(a, b) (((a) > (b)) ? (a) : (b)) +#define MIN(a, b) (((a) < (b)) ? (a) : (b)) + static _Bool parse_host_port(char *hostport, char **host, char **port) { @@ -92,6 +99,83 @@ connect_socket(const char *host, const char *port) return sfd; } +static _Bool +write_all(int fd, const void *buf, size_t count) +{ + ssize_t ret; + + while (count > 0) { + while ((ret = write(fd, buf, count)) < 0 && errno == EINTR) { + continue; + } + if (ret < 0) { + return false; + } + buf = ((const char *) buf) + ret; + count -= ret; + } + + return true; +} + +static _Bool +poll_fds(int sfd, WOLFSSL *ssl) +{ + struct pollfd fds[2] = { + { .fd = -1, .events = POLLIN|POLLERR, .revents = 0 }, + { .fd = -1, .events = POLLIN|POLLERR, .revents = 0 }, + }; + char buf[MAX(8192, WOLFSSL_MAX_ERROR_SZ)]; + ssize_t len; + int ret; + + fds[0].fd = STDIN_FILENO; + fds[1].fd = sfd; + + for (;;) { + while (poll(fds, ARRAY_SIZE(fds), -1) < 0 && (errno == EINTR || + errno == EAGAIN)) { + continue; + } + if (fds[0].revents > 0) { /* stdin */ + len = read(STDIN_FILENO, buf, sizeof(buf)); + if (len < 0) { + fputs("Input read error\n", stderr); + return false; + } else if (len == 0) { + fds[0].fd = -1; /* Stop polling. */ + } else if ((ret = wolfSSL_write(ssl, buf, len)) <= 0) { + wolfSSL_ERR_error_string(wolfSSL_get_error(ssl, + ret), buf); + fprintf(stderr, "Socket write error: %s\n", + buf); + return false; + } + } + if (fds[1].revents > 0) { /* socket */ + ret = wolfSSL_read(ssl, buf, MIN(sizeof(buf), 1024)); + if (ret < 0) { + wolfSSL_ERR_error_string(wolfSSL_get_error(ssl, + ret), buf); + fprintf(stderr, "Socket read error: %s\n", + buf); + return false; + } else if (ret == 0) { + fds[1].fd = -1; /* Stop polling. */ + close(STDOUT_FILENO); /* Signal socket EOF. */ + } else if (write_all(STDOUT_FILENO, buf, (size_t) ret) + == false) { + return false; + } + } + if (fds[0].fd == fds[1].fd) { /* Both -1 (no longer polled) */ + return true; + } + } + + /* Unreached */ +} + int s_client(int argc, char **argv) { @@ -177,6 +261,10 @@ s_client(int argc, char **argv) goto ssl_free; } + if (poll_fds(sfd, ssl) == false) { + ret = EXIT_FAILURE; + } + close(sfd); ssl_free: |