aboutsummaryrefslogtreecommitdiff
path: root/comms.c
blob: 6946fd94acb63d03720b2c86e10c65b253923ba3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
/*******************************************************************************
 * Copyright (c) 2010 Linaro Limited
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *     Peter Maydell (Linaro) - initial implementation
 ******************************************************************************/

/* Routines for the socket communication between master and apprentice. */

#include <unistd.h>
#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <netinet/in.h>
#include <netdb.h>

#include "risu.h"

/* Connect to the master as the client end of the TCP connection.
 *
 * hostname is resolved with gethostbyname() and must yield an IPv4
 * address; port is given in host byte order.
 * Returns the connected socket descriptor. Any failure (socket creation,
 * name resolution, connect) prints a diagnostic and exits the process.
 */
int apprentice_connect(const char *hostname, int port)
{
    /* Zero the whole address so sin_zero (and sin_len on BSDs) are not
     * left holding stack garbage.
     */
    struct sockaddr_in sa = { 0 };
    struct hostent *hostinfo;
    int sock = socket(PF_INET, SOCK_STREAM, 0);
    if (sock < 0) {
        perror("socket");
        exit(1);
    }
    sa.sin_family = AF_INET;
    sa.sin_port = htons(port);
    hostinfo = gethostbyname(hostname);
    if (!hostinfo) {
        fprintf(stderr, "Unknown host %s\n", hostname);
        exit(1);
    }
    /* We copy the address straight into an in_addr, so insist on an
     * IPv4 entry that actually carries an address.
     */
    if (hostinfo->h_addrtype != AF_INET ||
        hostinfo->h_length != (int) sizeof(sa.sin_addr) ||
        !hostinfo->h_addr_list[0]) {
        fprintf(stderr, "Host %s has no IPv4 address\n", hostname);
        exit(1);
    }
    /* h_addr_list[0] is the portable spelling of the legacy h_addr macro */
    sa.sin_addr = *(struct in_addr *) hostinfo->h_addr_list[0];
    if (connect(sock, (struct sockaddr *) &sa, sizeof(sa)) < 0) {
        perror("connect");
        exit(1);
    }
    return sock;
}

/* Act as the server end of the TCP connection.
 *
 * Binds a listening socket (with SO_REUSEADDR) to the given port on all
 * IPv4 interfaces, blocks until exactly one client connects, then closes
 * the listening socket.
 * Returns the connected socket descriptor. Any failure prints a
 * diagnostic and exits the process.
 */
int master_connect(int port)
{
    /* Zero the whole address: some stacks reject bind() when sin_zero
     * contains non-zero bytes.
     */
    struct sockaddr_in sa = { 0 };
    int sora = 1;
    int nsock;
    int sock = socket(PF_INET, SOCK_STREAM, 0);
    if (sock < 0) {
        perror("socket");
        exit(1);
    }
    /* Allow an immediate restart on the same port after a previous run */
    if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &sora, sizeof(sora)) !=
        0) {
        perror("setsockopt(SO_REUSEADDR)");
        exit(1);
    }

    sa.sin_family = AF_INET;
    sa.sin_port = htons(port);
    sa.sin_addr.s_addr = htonl(INADDR_ANY);
    if (bind(sock, (struct sockaddr *) &sa, sizeof(sa)) < 0) {
        perror("bind");
        exit(1);
    }
    if (listen(sock, 1) < 0) {
        perror("listen");
        exit(1);
    }
    /* Just block until we get a connection */
    fprintf(stderr, "master: waiting for connection on port %d...\n",
            port);
    /* The peer address is never used, so don't ask accept() for it */
    nsock = accept(sock, NULL, NULL);
    if (nsock < 0) {
        perror("accept");
        exit(1);
    }
    /* We're done with the server socket now */
    close(sock);
    return nsock;
}

/* Utility functions which are just wrappers around read and writev
 * to catch errors and retry on short reads/writes.
 */
/* Read exactly pktlen bytes from sock into pkt.
 *
 * Short reads are continued and EINTR is retried. A read error or an
 * end-of-file before pktlen bytes have arrived prints a diagnostic and
 * exits the process. A pktlen of zero or less reads nothing.
 */
static void recv_bytes(int sock, void *pkt, int pktlen)
{
    char *p = pkt;
    while (pktlen > 0) {
        ssize_t i = read(sock, p, pktlen);
        if (i < 0) {
            if (errno == EINTR) {
                continue;
            }
            perror("read failed");
            exit(1);
        }
        if (i == 0) {
            /* read() does not set errno on EOF, so perror() would report
             * a stale value (possibly EINTR, which used to spin forever).
             */
            fprintf(stderr, "read failed: unexpected end of file\n");
            exit(1);
        }
        pktlen -= (int) i;
        p += i;
    }
}

/* Read and throw away exactly pktlen bytes from sock.
 *
 * Data is consumed through a 64-byte scratch buffer. Short reads are
 * continued and EINTR is retried. A read error or a premature
 * end-of-file prints a diagnostic and exits the process. A pktlen of
 * zero or less consumes nothing.
 */
static void recv_and_discard_bytes(int sock, int pktlen)
{
    char dumpbuf[64];
    while (pktlen > 0) {
        int len = pktlen < (int) sizeof(dumpbuf) ? pktlen
                                                 : (int) sizeof(dumpbuf);
        ssize_t i = read(sock, dumpbuf, len);
        if (i < 0) {
            if (errno == EINTR) {
                continue;
            }
            perror("read failed");
            exit(1);
        }
        if (i == 0) {
            /* read() leaves errno untouched on EOF; report it explicitly */
            fprintf(stderr, "read failed: unexpected end of file\n");
            exit(1);
        }
        pktlen -= (int) i;
    }
}

/* writev() that retries on EINTR and continues after short writes.
 *
 * Writes every byte described by the iovcnt entries of iov_in to fd.
 * Returns the total number of bytes written, or -1 (with errno set by
 * writev) on a non-EINTR error. With iovcnt <= 0 nothing is written
 * and 0 is returned.
 *
 * NOTE: the caller's iovec array is modified. The iov_base and iov_len
 * of a partially written entry are adjusted, so do not reuse the array
 * afterwards without reinitialising it.
 */
ssize_t safe_writev(int fd, struct iovec *iov_in, int iovcnt)
{
    ssize_t r = 0;
    struct iovec *iov = iov_in;

    if (iovcnt <= 0) {
        return 0;
    }
    for (;;) {
        ssize_t i = writev(fd, iov, iovcnt);
        if (i == -1) {
            if (errno == EINTR) {
                continue;
            }
            return -1;
        }
        r += i;
        /* Step past every entry that has been completely transferred */
        while ((size_t) i >= iov->iov_len) {
            i -= (ssize_t) iov->iov_len;
            iov++;
            iovcnt--;
            if (iovcnt == 0) {
                return r;
            }
        }
        /* This entry was only partly sent: advance its start as well as
         * shrinking its length, or the retry would resend the bytes that
         * already went out.
         */
        iov->iov_base = (char *) iov->iov_base + i;
        iov->iov_len -= (size_t) i;
    }
}

/* Low level comms routines:
 * send_data_pkt sends a block of data and waits for
 * a single byte response code.
 * recv_data_pkt receives a block of data.
 * send_response_byte sends the response code.
 * Note that both ends must agree on the length of the
 * block of data.
 */
/* Send pktlen bytes from pkt, prefixed by their length, then block for
 * the peer's single-byte response code and return it (0..255).
 * Any I/O failure, including the peer closing the connection before
 * replying, prints a diagnostic and exits the process.
 */
int send_data_pkt(int sock, void *pkt, int pktlen)
{
    unsigned char resp;
    /* First we send the packet length as a network-order 32 bit value.
     * This avoids silent deadlocks if the two sides disagree over
     * what size data packet they are transferring. We use writev()
     * so that both length and packet are sent in one packet; otherwise
     * we get 300x slowdown because we hit Nagle's algorithm.
     */
    uint32_t net_pktlen = htonl((uint32_t) pktlen);
    struct iovec iov[2];
    iov[0].iov_base = &net_pktlen;
    iov[0].iov_len = sizeof(net_pktlen);
    iov[1].iov_base = pkt;
    iov[1].iov_len = pktlen;

    if (safe_writev(sock, iov, 2) == -1) {
        perror("writev failed");
        exit(1);
    }

    /* Wait for the response byte, retrying on EINTR like the other
     * read helpers in this file.
     */
    for (;;) {
        ssize_t n = read(sock, &resp, 1);
        if (n == 1) {
            break;
        }
        if (n < 0 && errno == EINTR) {
            continue;
        }
        if (n == 0) {
            /* EOF does not set errno, so perror() would be misleading */
            fprintf(stderr, "read failed: unexpected end of file\n");
            exit(1);
        }
        perror("read failed");
        exit(1);
    }
    return resp;
}

/* Receive a length-prefixed block of data into pkt.
 *
 * The sender's 32-bit network-order length must equal pktlen. On a
 * match, pktlen bytes are read into pkt and 0 is returned. On a
 * mismatch (or a negative pktlen), the sender's payload is read and
 * discarded so the stream stays in step for the response byte, and 1 is
 * returned.
 */
int recv_data_pkt(int sock, void *pkt, int pktlen)
{
    /* Largest amount handed to a single discard call; keeps each chunk
     * representable as a positive int.
     */
    enum { DISCARD_CHUNK = 1 << 30 };
    uint32_t net_pktlen;
    recv_bytes(sock, &net_pktlen, sizeof(net_pktlen));
    net_pktlen = ntohl(net_pktlen);
    if (pktlen < 0 || (uint32_t) pktlen != net_pktlen) {
        /* Mismatch. Read the data anyway so we can send a response
         * back. The length came off the wire and may exceed INT_MAX;
         * converting it directly to int would yield a negative count,
         * so drain it in int-sized chunks instead.
         */
        uint32_t remaining = net_pktlen;
        while (remaining > 0) {
            int chunk = remaining > (uint32_t) DISCARD_CHUNK
                            ? DISCARD_CHUNK
                            : (int) remaining;
            recv_and_discard_bytes(sock, chunk);
            remaining -= (uint32_t) chunk;
        }
        return 1;
    }
    recv_bytes(sock, pkt, pktlen);
    return 0;
}

/* Send the single-byte response code resp (truncated to unsigned char)
 * on sock. EINTR is retried; any other write failure prints a
 * diagnostic and exits the process.
 */
void send_response_byte(int sock, int resp)
{
    unsigned char r = (unsigned char) resp;
    for (;;) {
        ssize_t n = write(sock, &r, 1);
        if (n == 1) {
            return;
        }
        /* Retry on signal interruption, consistent with safe_writev */
        if (n < 0 && errno == EINTR) {
            continue;
        }
        perror("write failed");
        exit(1);
    }
}