diff --git a/uppsrc/Core/Http.cpp b/uppsrc/Core/Http.cpp index ffeb273b3..c696dc85f 100644 --- a/uppsrc/Core/Http.cpp +++ b/uppsrc/Core/Http.cpp @@ -1,685 +1,685 @@ -#include "Core.h" - -NAMESPACE_UPP - -bool HttpRequest_Trace__; - -#define LLOG(x) do { if(HttpRequest_Trace__) RLOG(x); } while(0) - -#ifdef _DEBUG -_DBG_ -// #define ENDZIP -#endif - -void HttpRequest::Trace(bool b) -{ - HttpRequest_Trace__ = b; -} - -void HttpRequest::Init() -{ - port = 0; - proxy_port = 0; - ssl_proxy_port = 0; - max_header_size = 1000000; - max_content_size = 10000000; - max_redirects = 5; - max_retries = 3; - force_digest = false; - std_headers = true; - hasurlvar = false; - method = METHOD_GET; - phase = BEGIN; - redirect_count = 0; - retry_count = 0; - gzip = false; - WhenContent = callback(this, &HttpRequest::ContentOut); - chunk = 4096; - timeout = 120000; - ssl = false; -} - -HttpRequest::HttpRequest() -{ - Init(); -} - -HttpRequest::HttpRequest(const char *url) -{ - Init(); - Url(url); -} - -HttpRequest& HttpRequest::Url(const char *u) -{ - ssl = memcmp(u, "https", 5) == 0; - const char *t = u; - while(*t && *t != '?') - if(*t++ == '/' && *t == '/') { - u = ++t; - break; - } - t = u; - while(*u && *u != ':' && *u != '/' && *u != '?') - u++; - if(*u == '?' && u[1]) - hasurlvar = true; - host = String(t, u); - port = 0; - if(*u == ':') - port = ScanInt(u + 1, &u); - path = u; - int q = path.Find('#'); - if(q >= 0) - path.Trim(q); - return *this; -} - -static -void sParseProxyUrl(const char *p, String& proxy_host, int proxy_port) -{ - const char *t = p; - while(*p && *p != ':') - p++; - proxy_host = String(t, p); - if(*p++ == ':' && IsDigit(*p)) - proxy_port = ScanInt(p); -} - -HttpRequest& HttpRequest::Proxy(const char *url) -{ - proxy_port = 80; - sParseProxyUrl(url, proxy_host, proxy_port); - return *this; -} - -HttpRequest& HttpRequest::SSLProxy(const char *url) -{ - ssl_proxy_port = 8080; - sParseProxyUrl(url, ssl_proxy_host, ssl_proxy_port); - return *this; -} - -HttpRequest& HttpRequest::Post(const char *id, const String& data) -{ - POST(); - if(postdata.GetCount()) - postdata << '&'; - postdata << id << '=' << UrlEncode(data); - return *this; -} - -HttpRequest& HttpRequest::UrlVar(const char *id, const String& data) -{ - int c = *path.Last(); - if(hasurlvar && c != '&') - path << '&'; - if(!hasurlvar && c != '?') - path << '?'; - path << id << '=' << UrlEncode(data); - hasurlvar = true; - return *this; -} - -String HttpRequest::CalculateDigest(const String& authenticate) const -{ - const char *p = authenticate; - String realm, qop, nonce, opaque; - while(*p) { - if(!IsAlNum(*p)) { - p++; - continue; - } - else { - const char *b = p; - while(IsAlNum(*p)) - p++; - String var = ToLower(String(b, p)); - String value; - while(*p && (byte)*p <= ' ') - p++; - if(*p == '=') { - p++; - while(*p && (byte)*p <= ' ') - p++; - if(*p == '\"') { - p++; - while(*p && *p != '\"') - if(*p != '\\' || *++p) - value.Cat(*p++); - if(*p == '\"') - p++; - } - else { - b = p; - while(*p && *p != ',' && (byte)*p > ' ') - p++; - value = String(b, p); - } - } - if(var == "realm") - realm = value; - else if(var == "qop") - qop = value; - else if(var == "nonce") - nonce = value; - else if(var == "opaque") - opaque = value; - } - } - String hv1, hv2; - hv1 << username << ':' << realm << ':' << password; - String ha1 = MD5String(hv1); - hv2 << (method == METHOD_GET ? "GET" : method == METHOD_PUT ? "PUT" : method == METHOD_POST ? "POST" : "READ") - << ':' << path; - String ha2 = MD5String(hv2); - int nc = 1; - String cnonce = FormatIntHex(Random(), 8); - String hv; - hv << ha1 - << ':' << nonce - << ':' << FormatIntHex(nc, 8) - << ':' << cnonce - << ':' << qop << ':' << ha2; - String ha = MD5String(hv); - String auth; - auth << "username=" << AsCString(username) - << ", realm=" << AsCString(realm) - << ", nonce=" << AsCString(nonce) - << ", uri=" << AsCString(path) - << ", qop=" << AsCString(qop) - << ", nc=" << AsCString(FormatIntHex(nc, 8)) - << ", cnonce=" << cnonce - << ", response=" << AsCString(ha); - if(!IsNull(opaque)) - auth << ", opaque=" << AsCString(opaque); - return auth; -} - -HttpRequest& HttpRequest::Header(const char *id, const String& data) -{ - request_headers << id << ": " << data << "\r\n"; - return *this; -} - -void HttpRequest::HttpError(const char *s) -{ - if(IsError()) - return; - error = NFormat(t_("%s:%d: ") + String(s), host, port); - LLOG("HTTP ERROR: " << error); - Close(); -} - -void HttpRequest::StartPhase(int s) -{ - phase = s; - LLOG("Starting status " << s << " '" << GetPhaseName() << "', url: " << host); - data.Clear(); -} - -bool HttpRequest::Do() -{ - int c1, c2; - switch(phase) { - case BEGIN: - retry_count = 0; - redirect_count = 0; - start_time = msecs(); - case START: - Start(); - break; - case DNS: - Dns(); - break; - case SSLPROXYREQUEST: - if(SendingData()) - break; - StartPhase(SSLPROXYRESPONSE); - break; - case SSLPROXYRESPONSE: - if(ReadingHeader()) - break; - ProcessSSLProxyResponse(); - break; - case SSLHANDSHAKE: - if(SSLHandshake()) - break; - StartRequest(); - break; - case REQUEST: - if(SendingData()) - break; - StartPhase(HEADER); - break; - case HEADER: - if(ReadingHeader()) - break; - StartBody(); - break; - case BODY: - if(ReadingBody()) - break; - Finish(); - break; - case CHUNK_HEADER: - ReadingChunkHeader(); - break; - case CHUNK_BODY: - if(ReadingBody()) - break; - c1 = Get(); - c2 = Get(); - if(c1 != '\r' || c2 != '\n') - HttpError("missing ending CRLF in chunked transfer"); - StartPhase(CHUNK_HEADER); - break; - case TRAILER: - if(ReadingHeader()) - break; - header.Parse(data); - Finish(); - break; - case FINISHED: - case FAILED: - return false; - default: - NEVER(); - } - - if(phase != FAILED) - if(IsSocketError() || IsError()) - phase = FAILED; - else - if(msecs() - start_time >= timeout) { - HttpError("connection timed out"); - phase = FAILED; - } - else - if(IsAbort()) { - HttpError("connection was aborted"); - phase = FAILED; - } - - if(phase == FAILED) { - if(retry_count++ < max_retries) { - LLOG("HTTP retry on error " << GetErrorDesc()); - start_time = msecs(); - StartPhase(START); - } - } - return phase != FINISHED && phase != FAILED; -} - -void HttpRequest::Start() -{ - Close(); - ClearError(); - gzip = false; - z.Clear(); - header.Clear(); - - bool use_proxy = !IsNull(ssl ? ssl_proxy_host : proxy_host); - - int p = use_proxy ? (ssl ? ssl_proxy_port : proxy_port) : port; - if(!p) - p = ssl ? DEFAULT_HTTPS_PORT : DEFAULT_HTTP_PORT; - String h = use_proxy ? ssl ? ssl_proxy_host : proxy_host : host; - - if(IsNull(GetTimeout())) { - addrinfo.Execute(h, p); - StartConnect(); - } - else { - addrinfo.Start(h, p); - StartPhase(DNS); - } -} - -void HttpRequest::Dns() -{ - for(int i = 0; i <= Nvl(GetTimeout(), INT_MAX); i++) { - if(!addrinfo.InProgress()) { - StartConnect(); - return; - } - Sleep(1); - } -} - -void HttpRequest::StartConnect() -{ - if(!Connect(addrinfo)) - return; - if(ssl && ssl_proxy_host.GetCount()) { - StartPhase(SSLPROXYREQUEST); - String host_port = host; - if(port) - host_port << ':' << port; - else - host_port << ":443"; - data << "CONNECT " << host_port << " HTTP/1.1\r\n" - << "Host: " << host_port << "\r\n"; - if(!IsNull(ssl_proxy_username)) - data << "Proxy-Authorization: Basic " - << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; - data << "\r\n"; - count = 0; - LLOG("HTTPS proxy request:\n" << data); - } - else - AfterConnect(); -} - -void HttpRequest::ProcessSSLProxyResponse() -{ - LLOG("HTTPS proxy response:\n" << data); - int q = min(data.Find('\r'), data.Find('\n')); - if(q >= 0) - data.Trim(q); - if(!data.StartsWith("HTTP") || data.Find(" 2") < 0) { - HttpError("Invalid proxy reply: " + data); - return; - } - AfterConnect(); -} - -void HttpRequest::AfterConnect() -{ - if(ssl && !StartSSL()) - return; - if(ssl) - StartPhase(SSLHANDSHAKE); - else - StartRequest(); -} - -void HttpRequest::StartRequest() -{ - StartPhase(REQUEST); - count = 0; - String ctype = contenttype; - if((method == METHOD_POST || method == METHOD_PUT) && IsNull(ctype)) - ctype = "application/x-www-form-urlencoded"; - switch(method) { - case METHOD_GET: data << "GET "; break; - case METHOD_POST: data << "POST "; break; - case METHOD_PUT: data << "PUT "; break; - case METHOD_HEAD: data << "HEAD "; break; - default: NEVER(); // invalid method - } - String host_port = host; - if(port) - host_port << ':' << port; - String url; - url << "http://" << host_port << Nvl(path, "/"); - if(!IsNull(proxy_host) && !ssl) - data << url; - else - data << Nvl(path, "/"); - data << " HTTP/1.1\r\n"; - if(std_headers) { - data << "URL: " << url << "\r\n" - << "Host: " << host_port << "\r\n" - << "Connection: close\r\n" - << "Accept: " << Nvl(accept, "*/*") << "\r\n" - << "Accept-Encoding: gzip\r\n" - << "User-Agent: " << Nvl(agent, "Ultimate++ HTTP client") << "\r\n"; - if(postdata.GetCount()) - data << "Content-Length: " << postdata.GetCount() << "\r\n"; - if(ctype.GetCount()) - data << "Content-Type: " << ctype << "\r\n"; - } - if(!IsNull(proxy_host) && !IsNull(proxy_username)) - data << "Proxy-Authorization: Basic " << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; - if(!IsNull(digest)) - data << "Authorization: Digest " << digest << "\r\n"; - else - if(!force_digest && (!IsNull(username) || !IsNull(password))) - data << "Authorization: Basic " << Base64Encode(username + ":" + password) << "\r\n"; - data << request_headers << "\r\n" << postdata; // !!! POST PHASE !!! - LLOG("HTTP REQUEST " << host << ":" << port); - LLOG("HTTP request:\n" << data); -} - -bool HttpRequest::SendingData() -{ - for(;;) { - int n = min(2048, data.GetLength() - count); - n = Put(~data + count, n); - if(n == 0) - break; - count += n; - } - return count < data.GetLength(); -} - -bool HttpRequest::ReadingHeader() -{ - for(;;) { - int c = Get(); - if(c < 0) - return !IsEof(); - else - data.Cat(c); - if(data.GetCount() > 3) { - const char *h = data.Last(); - if(h[0] == '\n' && (h[-1] == '\r' && h[-2] == '\n' || h[-1] == '\n')) - return false; - } - if(data.GetCount() > max_header_size) { - HttpError("HTTP header exceeded " + AsString(max_header_size)); - return true; - } - } -} - -void HttpRequest::ReadingChunkHeader() -{ - for(;;) { - int c = Get(); - if(c < 0) - break; - else - if(c == '\n') { - int n = ScanInt(~data, NULL, 16); - LLOG("HTTP Chunk header: 0x" << data << " = " << n); - if(IsNull(n)) { - HttpError("invalid chunk header"); - break; - } - if(n == 0) { - StartPhase(TRAILER); - break; - } - count += n; - StartPhase(CHUNK_BODY); - break; - } - if(c != '\r') - data.Cat(c); - } -} - -String HttpRequest::GetRedirectUrl() -{ - String redirect_url = TrimLeft(header["location"]); - if(redirect_url.StartsWith("http://") || redirect_url.StartsWith("https://")) - return redirect_url; - String h = (ssl ? "https://" : "http://") + host; - if(*redirect_url != '/') - h << '/'; - h << redirect_url; - return h; -} - -int HttpRequest::GetContentLength() -{ - return Nvl(ScanInt(header["content-length"]), -1); -} - -void HttpRequest::StartBody() -{ - LLOG("HTTP Header received: "); - LLOG(data); - header.Clear(); - if(!header.Parse(data)) { - HttpError("invalid HTTP header"); - return; - } - - if(!header.Response(protocol, status_code, reason_phrase)) { - HttpError("invalid HTTP response"); - return; - } - - LLOG("HTTP status code: " << status_code); - - count = GetContentLength(); - - if(count > 0) - body.Reserve(count); - - if(method == METHOD_HEAD) - phase = FINISHED; - else - if(header["transfer-encoding"] == "chunked") { - count = 0; - StartPhase(CHUNK_HEADER); - } - else - StartPhase(BODY); - body.Clear(); - bodylen = 0; - gzip = GetHeader("content-encoding") == "gzip"; - if(gzip) { - gzip = true; - z.WhenOut = callback(this, &HttpRequest::Out); - z.ChunkSize(chunk).GZip().Decompress(); - } -} - -void HttpRequest::ContentOut(const void *ptr, dword size) -{ - body.Cat((const char *)ptr, size); -} - -void HttpRequest::Out(const void *ptr, dword size) -{ - LLOG("HTTP Out " << size); - if(z.IsError()) - HttpError("gzip format error"); - int64 l = bodylen + size; - if(l > max_content_size) { - HttpError("content length exceeded " + AsString(max_content_size)); - phase = FAILED; - return; - } - WhenContent(ptr, size); - bodylen += size; -} - -bool HttpRequest::ReadingBody() -{ - LLOG("HTTP reading data " << count); - int n = chunk; - if(count >= 0) - n = min(n, count); - String s = Get(n); - if(s.GetCount() == 0) - return !IsEof() && count; -#ifndef ENDZIP - if(gzip) - z.Put(~s, s.GetCount()); - else -#endif - Out(~s, s.GetCount()); - if(count >= 0) { - count -= s.GetCount(); - return !IsEof() && count > 0; - } - return !IsEof(); -} - -void HttpRequest::CopyCookies() -{ - int q = header.fields.Find("set-cookie"); - while(q >= 0) { - Cookie(header.fields[q]); - q = header.fields.FindNext(q); - } -} - -void HttpRequest::Finish() -{ - if(gzip) { - #ifdef ENDZIP - body = GZDecompress(body); - if(body.IsVoid()) { - HttpError("gzip decompress at finish error"); - phase = FAILED; - return; - } - #else - z.End(); - if(z.IsError()) { - HttpError("gzip format error (finish)"); - phase = FAILED; - return; - } - #endif - } - Close(); - if(status_code == 401 && !IsNull(username)) { - String authenticate = header["www-authenticate"]; - if(authenticate.GetCount() && redirect_count++ < max_redirects) { - LLOG("HTTP auth digest"); - CopyCookies(); - Digest(CalculateDigest(authenticate)); - Start(); - return; - } - } - if(status_code >= 300 && status_code < 400) { - String url = GetRedirectUrl(); - if(url.GetCount() && redirect_count++ < max_redirects) { - LLOG("HTTP redirect " << url); - Url(url); - CopyCookies(); - Start(); - retry_count = 0; - return; - } - } - phase = FINISHED; -} - -String HttpRequest::Execute() -{ - while(Do()) - LLOG("HTTP Execute: " << GetPhaseName()); - return IsSuccess() ? GetContent() : String::GetVoid(); -} - -String HttpRequest::GetPhaseName() const -{ - static const char *m[] = { - "Initial state", - "Start", - "Resolving host name", - "SSL proxy request", - "SSL proxy response", - "SSL handshake", - "Sending request", - "Receiving header", - "Receiving content", - "Receiving chunk header", - "Receiving content chunk", - "Receiving trailer", - "Finished", - "Failed", - }; - return phase >= 0 && phase <= FAILED ? m[phase] : ""; -} - -END_UPP_NAMESPACE +#include "Core.h" + +NAMESPACE_UPP + +bool HttpRequest_Trace__; + +#define LLOG(x) do { if(HttpRequest_Trace__) RLOG(x); } while(0) + +#ifdef _DEBUG +_DBG_ +// #define ENDZIP +#endif + +void HttpRequest::Trace(bool b) +{ + HttpRequest_Trace__ = b; +} + +void HttpRequest::Init() +{ + port = 0; + proxy_port = 0; + ssl_proxy_port = 0; + max_header_size = 1000000; + max_content_size = 10000000; + max_redirects = 5; + max_retries = 3; + force_digest = false; + std_headers = true; + hasurlvar = false; + method = METHOD_GET; + phase = BEGIN; + redirect_count = 0; + retry_count = 0; + gzip = false; + WhenContent = callback(this, &HttpRequest::ContentOut); + chunk = 4096; + timeout = 120000; + ssl = false; +} + +HttpRequest::HttpRequest() +{ + Init(); +} + +HttpRequest::HttpRequest(const char *url) +{ + Init(); + Url(url); +} + +HttpRequest& HttpRequest::Url(const char *u) +{ + ssl = memcmp(u, "https", 5) == 0; + const char *t = u; + while(*t && *t != '?') + if(*t++ == '/' && *t == '/') { + u = ++t; + break; + } + t = u; + while(*u && *u != ':' && *u != '/' && *u != '?') + u++; + if(*u == '?' && u[1]) + hasurlvar = true; + host = String(t, u); + port = 0; + if(*u == ':') + port = ScanInt(u + 1, &u); + path = u; + int q = path.Find('#'); + if(q >= 0) + path.Trim(q); + return *this; +} + +static +void sParseProxyUrl(const char *p, String& proxy_host, int proxy_port) +{ + const char *t = p; + while(*p && *p != ':') + p++; + proxy_host = String(t, p); + if(*p++ == ':' && IsDigit(*p)) + proxy_port = ScanInt(p); +} + +HttpRequest& HttpRequest::Proxy(const char *url) +{ + proxy_port = 80; + sParseProxyUrl(url, proxy_host, proxy_port); + return *this; +} + +HttpRequest& HttpRequest::SSLProxy(const char *url) +{ + ssl_proxy_port = 8080; + sParseProxyUrl(url, ssl_proxy_host, ssl_proxy_port); + return *this; +} + +HttpRequest& HttpRequest::Post(const char *id, const String& data) +{ + POST(); + if(postdata.GetCount()) + postdata << '&'; + postdata << id << '=' << UrlEncode(data); + return *this; +} + +HttpRequest& HttpRequest::UrlVar(const char *id, const String& data) +{ + int c = *path.Last(); + if(hasurlvar && c != '&') + path << '&'; + if(!hasurlvar && c != '?') + path << '?'; + path << id << '=' << UrlEncode(data); + hasurlvar = true; + return *this; +} + +String HttpRequest::CalculateDigest(const String& authenticate) const +{ + const char *p = authenticate; + String realm, qop, nonce, opaque; + while(*p) { + if(!IsAlNum(*p)) { + p++; + continue; + } + else { + const char *b = p; + while(IsAlNum(*p)) + p++; + String var = ToLower(String(b, p)); + String value; + while(*p && (byte)*p <= ' ') + p++; + if(*p == '=') { + p++; + while(*p && (byte)*p <= ' ') + p++; + if(*p == '\"') { + p++; + while(*p && *p != '\"') + if(*p != '\\' || *++p) + value.Cat(*p++); + if(*p == '\"') + p++; + } + else { + b = p; + while(*p && *p != ',' && (byte)*p > ' ') + p++; + value = String(b, p); + } + } + if(var == "realm") + realm = value; + else if(var == "qop") + qop = value; + else if(var == "nonce") + nonce = value; + else if(var == "opaque") + opaque = value; + } + } + String hv1, hv2; + hv1 << username << ':' << realm << ':' << password; + String ha1 = MD5String(hv1); + hv2 << (method == METHOD_GET ? "GET" : method == METHOD_PUT ? "PUT" : method == METHOD_POST ? "POST" : "READ") + << ':' << path; + String ha2 = MD5String(hv2); + int nc = 1; + String cnonce = FormatIntHex(Random(), 8); + String hv; + hv << ha1 + << ':' << nonce + << ':' << FormatIntHex(nc, 8) + << ':' << cnonce + << ':' << qop << ':' << ha2; + String ha = MD5String(hv); + String auth; + auth << "username=" << AsCString(username) + << ", realm=" << AsCString(realm) + << ", nonce=" << AsCString(nonce) + << ", uri=" << AsCString(path) + << ", qop=" << AsCString(qop) + << ", nc=" << AsCString(FormatIntHex(nc, 8)) + << ", cnonce=" << cnonce + << ", response=" << AsCString(ha); + if(!IsNull(opaque)) + auth << ", opaque=" << AsCString(opaque); + return auth; +} + +HttpRequest& HttpRequest::Header(const char *id, const String& data) +{ + request_headers << id << ": " << data << "\r\n"; + return *this; +} + +void HttpRequest::HttpError(const char *s) +{ + if(IsError()) + return; + error = NFormat(t_("%s:%d: ") + String(s), host, port); + LLOG("HTTP ERROR: " << error); + Close(); +} + +void HttpRequest::StartPhase(int s) +{ + phase = s; + LLOG("Starting status " << s << " '" << GetPhaseName() << "', url: " << host); + data.Clear(); +} + +bool HttpRequest::Do() +{ + int c1, c2; + switch(phase) { + case BEGIN: + retry_count = 0; + redirect_count = 0; + start_time = msecs(); + case START: + Start(); + break; + case DNS: + Dns(); + break; + case SSLPROXYREQUEST: + if(SendingData()) + break; + StartPhase(SSLPROXYRESPONSE); + break; + case SSLPROXYRESPONSE: + if(ReadingHeader()) + break; + ProcessSSLProxyResponse(); + break; + case SSLHANDSHAKE: + if(SSLHandshake()) + break; + StartRequest(); + break; + case REQUEST: + if(SendingData()) + break; + StartPhase(HEADER); + break; + case HEADER: + if(ReadingHeader()) + break; + StartBody(); + break; + case BODY: + if(ReadingBody()) + break; + Finish(); + break; + case CHUNK_HEADER: + ReadingChunkHeader(); + break; + case CHUNK_BODY: + if(ReadingBody()) + break; + c1 = Get(); + c2 = Get(); + if(c1 != '\r' || c2 != '\n') + HttpError("missing ending CRLF in chunked transfer"); + StartPhase(CHUNK_HEADER); + break; + case TRAILER: + if(ReadingHeader()) + break; + header.Parse(data); + Finish(); + break; + case FINISHED: + case FAILED: + return false; + default: + NEVER(); + } + + if(phase != FAILED) + if(IsSocketError() || IsError()) + phase = FAILED; + else + if(msecs(start_time) >= timeout) { + HttpError("connection timed out"); + phase = FAILED; + } + else + if(IsAbort()) { + HttpError("connection was aborted"); + phase = FAILED; + } + + if(phase == FAILED) { + if(retry_count++ < max_retries) { + LLOG("HTTP retry on error " << GetErrorDesc()); + start_time = msecs(); + StartPhase(START); + } + } + return phase != FINISHED && phase != FAILED; +} + +void HttpRequest::Start() +{ + Close(); + ClearError(); + gzip = false; + z.Clear(); + header.Clear(); + + bool use_proxy = !IsNull(ssl ? ssl_proxy_host : proxy_host); + + int p = use_proxy ? (ssl ? ssl_proxy_port : proxy_port) : port; + if(!p) + p = ssl ? DEFAULT_HTTPS_PORT : DEFAULT_HTTP_PORT; + String h = use_proxy ? ssl ? ssl_proxy_host : proxy_host : host; + + if(IsNull(GetTimeout())) { + addrinfo.Execute(h, p); + StartConnect(); + } + else { + addrinfo.Start(h, p); + StartPhase(DNS); + } +} + +void HttpRequest::Dns() +{ + for(int i = 0; i <= Nvl(GetTimeout(), INT_MAX); i++) { + if(!addrinfo.InProgress()) { + StartConnect(); + return; + } + Sleep(1); + } +} + +void HttpRequest::StartConnect() +{ + if(!Connect(addrinfo)) + return; + if(ssl && ssl_proxy_host.GetCount()) { + StartPhase(SSLPROXYREQUEST); + String host_port = host; + if(port) + host_port << ':' << port; + else + host_port << ":443"; + data << "CONNECT " << host_port << " HTTP/1.1\r\n" + << "Host: " << host_port << "\r\n"; + if(!IsNull(ssl_proxy_username)) + data << "Proxy-Authorization: Basic " + << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; + data << "\r\n"; + count = 0; + LLOG("HTTPS proxy request:\n" << data); + } + else + AfterConnect(); +} + +void HttpRequest::ProcessSSLProxyResponse() +{ + LLOG("HTTPS proxy response:\n" << data); + int q = min(data.Find('\r'), data.Find('\n')); + if(q >= 0) + data.Trim(q); + if(!data.StartsWith("HTTP") || data.Find(" 2") < 0) { + HttpError("Invalid proxy reply: " + data); + return; + } + AfterConnect(); +} + +void HttpRequest::AfterConnect() +{ + if(ssl && !StartSSL()) + return; + if(ssl) + StartPhase(SSLHANDSHAKE); + else + StartRequest(); +} + +void HttpRequest::StartRequest() +{ + StartPhase(REQUEST); + count = 0; + String ctype = contenttype; + if((method == METHOD_POST || method == METHOD_PUT) && IsNull(ctype)) + ctype = "application/x-www-form-urlencoded"; + switch(method) { + case METHOD_GET: data << "GET "; break; + case METHOD_POST: data << "POST "; break; + case METHOD_PUT: data << "PUT "; break; + case METHOD_HEAD: data << "HEAD "; break; + default: NEVER(); // invalid method + } + String host_port = host; + if(port) + host_port << ':' << port; + String url; + url << "http://" << host_port << Nvl(path, "/"); + if(!IsNull(proxy_host) && !ssl) + data << url; + else + data << Nvl(path, "/"); + data << " HTTP/1.1\r\n"; + if(std_headers) { + data << "URL: " << url << "\r\n" + << "Host: " << host_port << "\r\n" + << "Connection: close\r\n" + << "Accept: " << Nvl(accept, "*/*") << "\r\n" + << "Accept-Encoding: gzip\r\n" + << "User-Agent: " << Nvl(agent, "Ultimate++ HTTP client") << "\r\n"; + if(postdata.GetCount()) + data << "Content-Length: " << postdata.GetCount() << "\r\n"; + if(ctype.GetCount()) + data << "Content-Type: " << ctype << "\r\n"; + } + if(!IsNull(proxy_host) && !IsNull(proxy_username)) + data << "Proxy-Authorization: Basic " << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; + if(!IsNull(digest)) + data << "Authorization: Digest " << digest << "\r\n"; + else + if(!force_digest && (!IsNull(username) || !IsNull(password))) + data << "Authorization: Basic " << Base64Encode(username + ":" + password) << "\r\n"; + data << request_headers << "\r\n" << postdata; // !!! POST PHASE !!! + LLOG("HTTP REQUEST " << host << ":" << port); + LLOG("HTTP request:\n" << data); +} + +bool HttpRequest::SendingData() +{ + for(;;) { + int n = min(2048, data.GetLength() - count); + n = Put(~data + count, n); + if(n == 0) + break; + count += n; + } + return count < data.GetLength(); +} + +bool HttpRequest::ReadingHeader() +{ + for(;;) { + int c = Get(); + if(c < 0) + return !IsEof(); + else + data.Cat(c); + if(data.GetCount() > 3) { + const char *h = data.Last(); + if(h[0] == '\n' && (h[-1] == '\r' && h[-2] == '\n' || h[-1] == '\n')) + return false; + } + if(data.GetCount() > max_header_size) { + HttpError("HTTP header exceeded " + AsString(max_header_size)); + return true; + } + } +} + +void HttpRequest::ReadingChunkHeader() +{ + for(;;) { + int c = Get(); + if(c < 0) + break; + else + if(c == '\n') { + int n = ScanInt(~data, NULL, 16); + LLOG("HTTP Chunk header: 0x" << data << " = " << n); + if(IsNull(n)) { + HttpError("invalid chunk header"); + break; + } + if(n == 0) { + StartPhase(TRAILER); + break; + } + count += n; + StartPhase(CHUNK_BODY); + break; + } + if(c != '\r') + data.Cat(c); + } +} + +String HttpRequest::GetRedirectUrl() +{ + String redirect_url = TrimLeft(header["location"]); + if(redirect_url.StartsWith("http://") || redirect_url.StartsWith("https://")) + return redirect_url; + String h = (ssl ? "https://" : "http://") + host; + if(*redirect_url != '/') + h << '/'; + h << redirect_url; + return h; +} + +int HttpRequest::GetContentLength() +{ + return Nvl(ScanInt(header["content-length"]), -1); +} + +void HttpRequest::StartBody() +{ + LLOG("HTTP Header received: "); + LLOG(data); + header.Clear(); + if(!header.Parse(data)) { + HttpError("invalid HTTP header"); + return; + } + + if(!header.Response(protocol, status_code, reason_phrase)) { + HttpError("invalid HTTP response"); + return; + } + + LLOG("HTTP status code: " << status_code); + + count = GetContentLength(); + + if(count > 0) + body.Reserve(count); + + if(method == METHOD_HEAD) + phase = FINISHED; + else + if(header["transfer-encoding"] == "chunked") { + count = 0; + StartPhase(CHUNK_HEADER); + } + else + StartPhase(BODY); + body.Clear(); + bodylen = 0; + gzip = GetHeader("content-encoding") == "gzip"; + if(gzip) { + gzip = true; + z.WhenOut = callback(this, &HttpRequest::Out); + z.ChunkSize(chunk).GZip().Decompress(); + } +} + +void HttpRequest::ContentOut(const void *ptr, dword size) +{ + body.Cat((const char *)ptr, size); +} + +void HttpRequest::Out(const void *ptr, dword size) +{ + LLOG("HTTP Out " << size); + if(z.IsError()) + HttpError("gzip format error"); + int64 l = bodylen + size; + if(l > max_content_size) { + HttpError("content length exceeded " + AsString(max_content_size)); + phase = FAILED; + return; + } + WhenContent(ptr, size); + bodylen += size; +} + +bool HttpRequest::ReadingBody() +{ + LLOG("HTTP reading data " << count); + int n = chunk; + if(count >= 0) + n = min(n, count); + String s = Get(n); + if(s.GetCount() == 0) + return !IsEof() && count; +#ifndef ENDZIP + if(gzip) + z.Put(~s, s.GetCount()); + else +#endif + Out(~s, s.GetCount()); + if(count >= 0) { + count -= s.GetCount(); + return !IsEof() && count > 0; + } + return !IsEof(); +} + +void HttpRequest::CopyCookies() +{ + int q = header.fields.Find("set-cookie"); + while(q >= 0) { + Cookie(header.fields[q]); + q = header.fields.FindNext(q); + } +} + +void HttpRequest::Finish() +{ + if(gzip) { + #ifdef ENDZIP + body = GZDecompress(body); + if(body.IsVoid()) { + HttpError("gzip decompress at finish error"); + phase = FAILED; + return; + } + #else + z.End(); + if(z.IsError()) { + HttpError("gzip format error (finish)"); + phase = FAILED; + return; + } + #endif + } + Close(); + if(status_code == 401 && !IsNull(username)) { + String authenticate = header["www-authenticate"]; + if(authenticate.GetCount() && redirect_count++ < max_redirects) { + LLOG("HTTP auth digest"); + CopyCookies(); + Digest(CalculateDigest(authenticate)); + Start(); + return; + } + } + if(status_code >= 300 && status_code < 400) { + String url = GetRedirectUrl(); + if(url.GetCount() && redirect_count++ < max_redirects) { + LLOG("HTTP redirect " << url); + Url(url); + CopyCookies(); + Start(); + retry_count = 0; + return; + } + } + phase = FINISHED; +} + +String HttpRequest::Execute() +{ + while(Do()) + LLOG("HTTP Execute: " << GetPhaseName()); + return IsSuccess() ? GetContent() : String::GetVoid(); +} + +String HttpRequest::GetPhaseName() const +{ + static const char *m[] = { + "Initial state", + "Start", + "Resolving host name", + "SSL proxy request", + "SSL proxy response", + "SSL handshake", + "Sending request", + "Receiving header", + "Receiving content", + "Receiving chunk header", + "Receiving content chunk", + "Receiving trailer", + "Finished", + "Failed", + }; + return phase >= 0 && phase <= FAILED ? m[phase] : ""; +} + +END_UPP_NAMESPACE diff --git a/uppsrc/Core/Socket.cpp b/uppsrc/Core/Socket.cpp index e0be57c20..494617708 100644 --- a/uppsrc/Core/Socket.cpp +++ b/uppsrc/Core/Socket.cpp @@ -1,893 +1,892 @@ -#include "Core.h" - -#ifdef PLATFORM_WIN32 -#include - #ifdef COMPILER_MSC - #include - #endif -#include -#endif - -#ifdef PLATFORM_POSIX -#include -#endif - -NAMESPACE_UPP - -#ifdef PLATFORM_WIN32 -#pragma comment(lib, "ws2_32.lib") -#endif - -#define LLOG(x) // DLOG("TCP " << x) - -IpAddrInfo::Entry IpAddrInfo::pool[COUNT]; - -AuxMutex IpAddrInfoPoolMutex; - -void IpAddrInfo::EnterPool() -{ - IpAddrInfoPoolMutex.Enter(); -} - -void IpAddrInfo::LeavePool() -{ - IpAddrInfoPoolMutex.Leave(); -} - -int sGetAddrInfo(const char *host, const char *port, addrinfo **result) -{ - addrinfo hints; - memset(&hints, 0, sizeof(addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; - - return getaddrinfo(host, port, &hints, result); -} - -auxthread_t auxthread__ IpAddrInfo::Thread(void *ptr) -{ - Entry *entry = (Entry *)ptr; - EnterPool(); - if(entry->status == WORKING) { - char host[1025]; - char port[257]; - strcpy(host, entry->host); - strcpy(port, entry->port); - LeavePool(); - addrinfo *result; - if(sGetAddrInfo(host, port, &result) == 0 && result) { - EnterPool(); - if(entry->status == WORKING) { - entry->addr = result; - entry->status = RESOLVED; - } - else { - freeaddrinfo(result); - entry->status = EMPTY; - } - } - else { - EnterPool(); - if(entry->status == CANCELED) - entry->status = EMPTY; - else - entry->status = FAILED; - } - } - LeavePool(); - return 0; -} - -bool IpAddrInfo::Execute(const String& host, int port) -{ - Clear(); - entry = exe; - addrinfo *result; - entry->addr = sGetAddrInfo(~host, ~AsString(port), &result) == 0 ? result : NULL; - return entry->addr; -} - -void IpAddrInfo::Start() -{ - if(entry) - return; - EnterPool(); - for(int i = 0; i < COUNT; i++) { - Entry *e = pool + i; - if(e->status == EMPTY) { - entry = e; - e->addr = NULL; - if(host.GetCount() > 1024 || port.GetCount() > 256) - e->status = FAILED; - else { - e->status = WORKING; - e->host = host; - e->port = port; - StartAuxThread(&IpAddrInfo::Thread, e); - } - break; - } - } - LeavePool(); -} - -void IpAddrInfo::Start(const String& host_, int port_) -{ - Clear(); - port = AsString(port_); - host = host_; - Start(); -} - -bool IpAddrInfo::InProgress() -{ - if(!entry) { - Start(); - return true; - } - EnterPool(); - int s = entry->status; - LeavePool(); - return s == WORKING; -} - -addrinfo *IpAddrInfo::GetResult() -{ - EnterPool(); - addrinfo *ai = entry ? entry->addr : NULL; - LeavePool(); - return ai; -} - -void IpAddrInfo::Clear() -{ - EnterPool(); - if(entry) { - if(entry->status == RESOLVED && entry->addr) - freeaddrinfo(entry->addr); - if(entry->status == WORKING) - entry->status = CANCELED; - else - entry->status = EMPTY; - entry = NULL; - } - LeavePool(); -} - -IpAddrInfo::IpAddrInfo() -{ - TcpSocket::Init(); - entry = NULL; -} - -#ifdef PLATFORM_POSIX - -#define SOCKERR(x) x - -const char *TcpSocketErrorDesc(int code) -{ - return strerror(code); -} - -int TcpSocket::GetErrorCode() -{ - return errno; -} - -#else - -#define SOCKERR(x) WSA##x - -const char *TcpSocketErrorDesc(int code) -{ - static Tuple2 err[] = { - { WSAEINTR, "Interrupted function call." }, - { WSAEACCES, "Permission denied." }, - { WSAEFAULT, "Bad address." }, - { WSAEINVAL, "Invalid argument." }, - { WSAEMFILE, "Too many open files." }, - { WSAEWOULDBLOCK, "Resource temporarily unavailable." }, - { WSAEINPROGRESS, "Operation now in progress." }, - { WSAEALREADY, "Operation already in progress." }, - { WSAENOTSOCK, "TcpSocket operation on nonsocket." }, - { WSAEDESTADDRREQ, "Destination address required." }, - { WSAEMSGSIZE, "Message too long." }, - { WSAEPROTOTYPE, "Protocol wrong type for socket." }, - { WSAENOPROTOOPT, "Bad protocol option." }, - { WSAEPROTONOSUPPORT, "Protocol not supported." }, - { WSAESOCKTNOSUPPORT, "TcpSocket type not supported." }, - { WSAEOPNOTSUPP, "Operation not supported." }, - { WSAEPFNOSUPPORT, "Protocol family not supported." }, - { WSAEAFNOSUPPORT, "Address family not supported by protocol family." }, - { WSAEADDRINUSE, "Address already in use." }, - { WSAEADDRNOTAVAIL, "Cannot assign requested address." }, - { WSAENETDOWN, "Network is down." }, - { WSAENETUNREACH, "Network is unreachable." }, - { WSAENETRESET, "Network dropped connection on reset." }, - { WSAECONNABORTED, "Software caused connection abort." }, - { WSAECONNRESET, "Connection reset by peer." }, - { WSAENOBUFS, "No buffer space available." }, - { WSAEISCONN, "TcpSocket is already connected." }, - { WSAENOTCONN, "TcpSocket is not connected." }, - { WSAESHUTDOWN, "Cannot send after socket shutdown." }, - { WSAETIMEDOUT, "Connection timed out." }, - { WSAECONNREFUSED, "Connection refused." }, - { WSAEHOSTDOWN, "Host is down." }, - { WSAEHOSTUNREACH, "No route to host." }, - { WSAEPROCLIM, "Too many processes." }, - { WSASYSNOTREADY, "Network subsystem is unavailable." }, - { WSAVERNOTSUPPORTED, "Winsock.dll version out of range." }, - { WSANOTINITIALISED, "Successful WSAStartup not yet performed." }, - { WSAEDISCON, "Graceful shutdown in progress." }, - { WSATYPE_NOT_FOUND, "Class type not found." }, - { WSAHOST_NOT_FOUND, "Host not found." }, - { WSATRY_AGAIN, "Nonauthoritative host not found." }, - { WSANO_RECOVERY, "This is a nonrecoverable error." }, - { WSANO_DATA, "Valid name, no data record of requested type." }, - { WSASYSCALLFAILURE, "System call failure." }, - }; - const Tuple2 *x = FindTuple(err, __countof(err), code); - return x ? x->b : "Unknown error code."; -} - -int TcpSocket::GetErrorCode() -{ - return WSAGetLastError(); -} - -#endif - -void TcpSocketInit() -{ -#if defined(PLATFORM_WIN32) - ONCELOCK { - WSADATA wsadata; - WSAStartup(MAKEWORD(2, 2), &wsadata); - } -#endif -} - -void TcpSocket::Init() -{ - TcpSocketInit(); -} - -void TcpSocket::Reset() -{ - is_eof = false; - socket = INVALID_SOCKET; - ipv6 = false; - ptr = end = buffer; - is_error = false; - is_abort = false; - mode = NONE; - ssl.Clear(); - sslinfo.Clear(); -} - -TcpSocket::TcpSocket() -{ - ClearError(); - Reset(); - timeout = Null; - waitstep = 20; - asn1 = false; -} - -bool TcpSocket::Open(int family, int type, int protocol) -{ - Init(); - Close(); - ClearError(); - if((socket = ::socket(family, type, protocol)) == INVALID_SOCKET) - return false; - LLOG("TcpSocket::Data::Open() -> " << (int)socket); -#ifdef PLATFORM_WIN32 - u_long arg = 1; - if(ioctlsocket(socket, FIONBIO, &arg)) - SetSockError("ioctlsocket(FIO[N]BIO)"); -#else - if(fcntl(socket, F_SETFL, (fcntl(socket, F_GETFL, 0) | O_NONBLOCK))) - SetSockError("fcntl(O_[NON]BLOCK)"); -#endif - return true; -} - -bool TcpSocket::Listen(int port, int listen_count, bool ipv6_, bool reuse) -{ - Close(); - Init(); - Reset(); - - ipv6 = ipv6_; - if(!Open(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) - return false; - sockaddr_in sin; -#ifdef PLATFORM_WIN32 - SOCKADDR_IN6 sin6; - if(ipv6 && IsWinVista()) -#else - sockaddr_in6 sin6; - if(ipv6) -#endif - { - Zero(sin6); - sin.sin_family = AF_INET6; - sin.sin_port = htons(port); - sin.sin_addr.s_addr = htonl(INADDR_ANY); - } - else { - Zero(sin); - sin.sin_family = AF_INET; - sin.sin_port = htons(port); - sin.sin_addr.s_addr = htonl(INADDR_ANY); - } - if(reuse) { - int optval = 1; - setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (const char *)&optval, sizeof(optval)); - } - if(bind(socket, ipv6 ? (const sockaddr *)&sin6 : (const sockaddr *)&sin, - ipv6 ? sizeof(sin6) : sizeof(sin))) { - SetSockError(Format("bind(port=%d)", port)); - return false; - } - if(listen(socket, listen_count)) { - SetSockError(Format("listen(port=%d, count=%d)", port, listen_count)); - return false; - } - return true; -} - -bool TcpSocket::Accept(TcpSocket& ls) -{ - Close(); - Init(); - Reset(); - ASSERT(ls.IsOpen()); - if(timeout) { - int h = ls.GetTimeout(); - bool b = ls.Timeout(timeout).Wait(WAIT_READ, GetEndTime()); - ls.Timeout(h); - if(!b) - return false; - } - if(!Open(ls.ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) - return false; - socket = accept(ls.GetSOCKET(), NULL, NULL); - if(socket == INVALID_SOCKET) { - SetSockError("accept"); - return false; - } - mode = ACCEPT; - return true; -} - -String TcpSocket::GetPeerAddr() const -{ - if(!IsOpen()) - return Null; - sockaddr_in addr; - socklen_t l = sizeof(addr); - if(getpeername(socket, (sockaddr *)&addr, &l) != 0) - return Null; - if(l > sizeof(addr)) - return Null; -#ifdef PLATFORM_WIN32 - return inet_ntoa(addr.sin_addr); -#else - char h[200]; - return inet_ntop(AF_INET, &addr.sin_addr, h, 200); -#endif -} - -void TcpSocket::NoDelay() -{ - ASSERT(IsOpen()); - int __true = 1; - LLOG("NoDelay(" << (int)socket << ")"); - if(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (const char *)&__true, sizeof(__true))) - SetSockError("setsockopt(TCP_NODELAY)"); -} - -void TcpSocket::Linger(int msecs) -{ - ASSERT(IsOpen()); - linger ls; - ls.l_onoff = !IsNull(msecs) ? 1 : 0; - ls.l_linger = !IsNull(msecs) ? (msecs + 999) / 1000 : 0; - if(setsockopt(socket, SOL_SOCKET, SO_LINGER, - reinterpret_cast(&ls), sizeof(ls))) - SetSockError("setsockopt(SO_LINGER)"); -} - -void TcpSocket::Attach(SOCKET s) -{ - Close(); - socket = s; -} - -bool TcpSocket::RawConnect(addrinfo *rp) -{ - if(!rp) { - SetSockError("connect", -1, "not found"); - return false; - } - for(;;) { - if(rp && Open(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) { - if(connect(socket, rp->ai_addr, rp->ai_addrlen) == 0 || - GetErrorCode() == SOCKERR(EINPROGRESS) || GetErrorCode() == SOCKERR(EWOULDBLOCK)) - break; - Close(); - } - rp = rp->ai_next; - if(!rp) { - SetSockError("connect", -1, "failed"); - return false; - } - } - mode = CONNECT; - return true; -} - - -bool TcpSocket::Connect(IpAddrInfo& info) -{ - LLOG("TCP Connect addrinfo"); - Init(); - Reset(); - addrinfo *result = info.GetResult(); - return RawConnect(result); -} - -bool TcpSocket::Connect(const char *host, int port) -{ - LLOG("TCP Connect(" << host << ':' << port << ')'); - - Init(); - Reset(); - IpAddrInfo info; - if(!info.Execute(host, port)) { - SetSockError(Format("getaddrinfo(%s) failed", host)); - return false; - } - return Connect(info); -} - -void TcpSocket::RawClose() -{ - LLOG("TCP close " << (int)socket); - if(socket != INVALID_SOCKET) { - int res; -#if defined(PLATFORM_WIN32) - res = closesocket(socket); -#elif defined(PLATFORM_POSIX) - res = close(socket); -#else - #error Unsupported platform -#endif - if(res && !IsError()) - SetSockError("close"); - socket = INVALID_SOCKET; - } -} - -void TcpSocket::Close() -{ - if(ssl) - ssl->Close(); - else - RawClose(); - ssl.Clear(); -} - -bool TcpSocket::WouldBlock() -{ - int c = GetErrorCode(); -#ifdef PLATFORM_POSIX - return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(EAGAIN); -#endif -#ifdef PLATFORM_WIN32 - return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(ENOTCONN); -#endif -} - -int TcpSocket::RawSend(const void *buf, int amount) -{ - int res = send(socket, (const char *)buf, amount, 0); - if(res < 0 && WouldBlock()) - res = 0; - else - if(res == 0 || res < 0) - SetSockError("send"); - return res; -} - -int TcpSocket::Send(const void *buf, int amount) -{ - if(SSLHandshake()) - return 0; - return ssl ? ssl->Send(buf, amount) : RawSend(buf, amount); -} - -void TcpSocket::Shutdown() -{ - ASSERT(IsOpen()); - if(shutdown(socket, SD_SEND)) - SetSockError("shutdown(SD_SEND)"); -} - -String TcpSocket::GetHostName() -{ - Init(); - char buffer[256]; - gethostname(buffer, __countof(buffer)); - return buffer; -} - -bool TcpSocket::RawWait(dword flags, int end_time) -{ - LLOG("Wait(" << msecs() << " - " << end_time << ", " << flags << ")"); - if((flags & WAIT_READ) && ptr != end) - return true; - if(socket == INVALID_SOCKET) - return false; - for(;;) { - if(IsError() || IsAbort()) - return false; - int to = end_time - msecs(); - if(WhenWait) - to = waitstep; - timeval *tvalp = NULL; - timeval tval; - if(!IsNull(timeout) || WhenWait) { - to = max(to, 0); - tval.tv_sec = to / 1000; - tval.tv_usec = 1000 * (to % 1000); - tvalp = &tval; - } - fd_set fdset[1]; - FD_ZERO(fdset); - FD_SET(socket, fdset); - int avail = select((int)socket + 1, - flags & WAIT_READ ? fdset : NULL, - flags & WAIT_WRITE ? fdset : NULL, - flags & WAIT_EXCEPTION ? fdset : NULL, tvalp); - LLOG("Wait select avail: " << avail); - if(avail < 0) { - SetSockError("wait"); - return false; - } - if(avail > 0) - return true; - if(to <= 0 && timeout) { - return false; - } - WhenWait(); - if(timeout == 0) - return false; - } -} - -bool TcpSocket::Wait(dword flags, int end_time) -{ - return ssl ? ssl->Wait(flags, end_time) : RawWait(flags, end_time); -} - -int TcpSocket::GetEndTime() const -{ - return IsNull(timeout) ? INT_MAX : msecs() + timeout; -} - -bool TcpSocket::Wait(dword flags) -{ - return Wait(flags, GetEndTime()); -} - -int TcpSocket::Put(const char *s, int length) -{ - LLOG("Put " << socket << ": " << length); - ASSERT(IsOpen()); - if(length < 0 && s) - length = (int)strlen(s); - if(!s || length <= 0 || IsError() || IsAbort()) - return 0; - done = 0; - bool peek = false; - int end_time = GetEndTime(); - while(done < length) { - if(peek && !Wait(WAIT_WRITE, end_time)) - return done; - peek = false; - int count = Send(s + done, length - done); - if(IsError() || timeout == 0 && count == 0 && peek) - return done; - if(count > 0) - done += count; - else - peek = true; - } - LLOG("//Put() -> " << done); - return done; -} - -bool TcpSocket::PutAll(const char *s, int len) -{ - if(Put(s, len) != len) { - if(!IsError()) - SetSockError("GePutAll", -1, "timeout"); - return false; - } - return true; -} - -bool TcpSocket::PutAll(const String& s) -{ - if(Put(s) != s.GetCount()) { - if(!IsError()) - SetSockError("GePutAll", -1, "timeout"); - return false; - } - return true; -} - -int TcpSocket::RawRecv(void *buf, int amount) -{ - int res = recv(socket, (char *)buf, amount, 0); - if(res == 0) - is_eof = true; - else - if(res < 0 && WouldBlock()) - res = 0; - else - if(res < 0) - SetSockError("recv"); - LLOG("recv(" << socket << "): " << res << " bytes: " - << AsCString((char *)buf, (char *)buf + min(res, 16)) - << (res ? "" : IsEof() ? ", EOF" : ", WOULDBLOCK")); - return res; -} - -int TcpSocket::Recv(void *buffer, int maxlen) -{ - if(SSLHandshake()) - return 0; - return ssl ? ssl->Recv(buffer, maxlen) : RawRecv(buffer, maxlen); -} - -void TcpSocket::ReadBuffer(int end_time) -{ - ptr = end = buffer; - if(Wait(WAIT_READ, end_time)) - end = buffer + Recv(buffer, BUFFERSIZE); -} - -int TcpSocket::Get_() -{ - if(!IsOpen() || IsError() || IsEof() || IsAbort()) - return -1; - ReadBuffer(GetEndTime()); - return ptr < end ? *ptr++ : -1; -} - -int TcpSocket::Peek_(int end_time) -{ - if(!IsOpen() || IsError() || IsEof() || IsAbort()) - return -1; - ReadBuffer(end_time); - return ptr < end ? *ptr : -1; -} - -int TcpSocket::Peek_() -{ - return Peek_(GetEndTime()); -} - -int TcpSocket::Get(void *buffer, int count) -{ - LLOG("Get " << count); - - if(!IsOpen() || IsError() || IsEof() || IsAbort()) - return 0; - - String out; - int l = end - ptr; - done = 0; - if(l > 0) - if(l < count) { - memcpy(buffer, ptr, l); - done += l; - ptr = end; - } - else { - memcpy(buffer, ptr, count); - ptr += count; - return count; - } - int end_time = GetEndTime(); - while(done < count && !IsError() && !IsEof()) { - if(!Wait(WAIT_READ, end_time)) - break; - int part = Recv((char *)buffer + done, count - done); - if(part > 0) - done += part; - if(timeout == 0) - break; - } - return done; -} - -String TcpSocket::Get(int count) -{ - if(count == 0) - return Null; - StringBuffer out(count); - int done = Get(out, count); - if(!done && IsEof()) - return String::GetVoid(); - out.SetLength(done); - return out; -} - -bool TcpSocket::GetAll(void *buffer, int len) -{ - if(Get(buffer, len) == len) - return true; - if(!IsError()) - SetSockError("GetAll", -1, "timeout"); - return false; -} - -String TcpSocket::GetAll(int len) -{ - String s = Get(len); - if(s.GetCount() != len) { - if(!IsError()) - SetSockError("GetAll", -1, "timeout"); - return String::GetVoid(); - } - return s; -} - -String TcpSocket::GetLine(int maxlen) -{ - String ln; - int end_time = GetEndTime(); - for(;;) { - int c = Peek(end_time); - if(c < 0) { - if(!IsError()) - SetSockError("GetLine", -1, "timeout"); - return String::GetVoid(); - } - Get(); - if(c == '\n') - return ln; - if(ln.GetCount() >= maxlen) { - if(!IsError()) - SetSockError("GetLine", -1, "maximal length exceeded"); - return String::GetVoid(); - } - if(c != '\r') - ln.Cat(c); - } -} - -void TcpSocket::SetSockError(const char *context, int code, const char *errdesc) -{ - errorcode = code; - errordesc.Clear(); - if(socket != INVALID_SOCKET) - errordesc << "socket(" << (int)socket << ") / "; - errordesc << context << ": " << errdesc; - is_error = true; - LLOG("TCP ERROR " << errordesc); -} - -void TcpSocket::SetSockError(const char *context, const char *errdesc) -{ - SetSockError(context, GetErrorCode(), errdesc); -} - -void TcpSocket::SetSockError(const char *context) -{ - SetSockError(context, TcpSocketErrorDesc(GetErrorCode())); -} - -TcpSocket::SSL *(*TcpSocket::CreateSSL)(TcpSocket& socket); - -bool TcpSocket::StartSSL() -{ - ASSERT(IsOpen()); - if(!CreateSSL) { - SetSockError("StartSSL", -1, "Missing SSL support (Core/SSL)"); - return false; - } - if(!IsOpen()) { - SetSockError("StartSSL", -1, "Socket is not open"); - return false; - } - if(mode != CONNECT && mode != ACCEPT) { - SetSockError("StartSSL", -1, "Socket is not connected"); - return false; - } - ssl = (*CreateSSL)(*this); - if(!ssl->Start()) { - ssl.Clear(); - return false; - } - SSLHandshake(); - return true; -} - -bool TcpSocket::SSLHandshake() -{ - if(ssl && (mode == CONNECT || mode == ACCEPT)) { - dword w = ssl->Handshake(); - if(w) { - Wait(w); - return ssl->Handshake(); - } - } - return false; -} - -void TcpSocket::SSLCertificate(const String& cert_, const String& pkey_, bool asn1_) -{ - cert = cert_; - pkey = pkey_; - asn1 = asn1_; -} - -int SocketWaitEvent::Wait(int timeout) -{ - FD_ZERO(read); - FD_ZERO(write); - FD_ZERO(exception); - int maxindex = -1; - for(int i = 0; i < socket.GetCount(); i++) { - const Tuple2& s = socket[i]; - if(s.a >= 0) { - const Tuple2& s = socket[i]; - if(s.b & WAIT_READ) - FD_SET(s.a, read); - if(s.b & WAIT_WRITE) - FD_SET(s.a, write); - if(s.b & WAIT_EXCEPTION) - FD_SET(s.a, exception); - maxindex = max(s.a, maxindex); - } - } - timeval *tvalp = NULL; - timeval tval; - if(!IsNull(timeout)) { - tval.tv_sec = timeout / 1000; - tval.tv_usec = 1000 * (timeout % 1000); - tvalp = &tval; - } - return select(maxindex + 1, read, write, exception, tvalp); -} - -dword SocketWaitEvent::Get(int i) const -{ - int s = socket[i].a; - if(s < 0) - return 0; - dword events = 0; - if(FD_ISSET(s, read)) - events |= WAIT_READ; - if(FD_ISSET(s, write)) - events |= WAIT_WRITE; - if(FD_ISSET(s, exception)) - events |= WAIT_EXCEPTION; - return events; -} - -SocketWaitEvent::SocketWaitEvent() -{ - FD_ZERO(read); - FD_ZERO(write); - FD_ZERO(exception); -} - -END_UPP_NAMESPACE +#include "Core.h" + +#ifdef PLATFORM_WIN32 +#include + #ifdef COMPILER_MSC + #include + #endif +#include +#endif + +#ifdef PLATFORM_POSIX +#include +#endif + +NAMESPACE_UPP + +#ifdef PLATFORM_WIN32 +#pragma comment(lib, "ws2_32.lib") +#endif + +#define LLOG(x) // DLOG("TCP " << x) + +IpAddrInfo::Entry IpAddrInfo::pool[COUNT]; + +AuxMutex IpAddrInfoPoolMutex; + +void IpAddrInfo::EnterPool() +{ + IpAddrInfoPoolMutex.Enter(); +} + +void IpAddrInfo::LeavePool() +{ + IpAddrInfoPoolMutex.Leave(); +} + +int sGetAddrInfo(const char *host, const char *port, addrinfo **result) +{ + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + return getaddrinfo(host, port, &hints, result); +} + +auxthread_t auxthread__ IpAddrInfo::Thread(void *ptr) +{ + Entry *entry = (Entry *)ptr; + EnterPool(); + if(entry->status == WORKING) { + char host[1025]; + char port[257]; + strcpy(host, entry->host); + strcpy(port, entry->port); + LeavePool(); + addrinfo *result; + if(sGetAddrInfo(host, port, &result) == 0 && result) { + EnterPool(); + if(entry->status == WORKING) { + entry->addr = result; + entry->status = RESOLVED; + } + else { + freeaddrinfo(result); + entry->status = EMPTY; + } + } + else { + EnterPool(); + if(entry->status == CANCELED) + entry->status = EMPTY; + else + entry->status = FAILED; + } + } + LeavePool(); + return 0; +} + +bool IpAddrInfo::Execute(const String& host, int port) +{ + Clear(); + entry = exe; + addrinfo *result; + entry->addr = sGetAddrInfo(~host, ~AsString(port), &result) == 0 ? result : NULL; + return entry->addr; +} + +void IpAddrInfo::Start() +{ + if(entry) + return; + EnterPool(); + for(int i = 0; i < COUNT; i++) { + Entry *e = pool + i; + if(e->status == EMPTY) { + entry = e; + e->addr = NULL; + if(host.GetCount() > 1024 || port.GetCount() > 256) + e->status = FAILED; + else { + e->status = WORKING; + e->host = host; + e->port = port; + StartAuxThread(&IpAddrInfo::Thread, e); + } + break; + } + } + LeavePool(); +} + +void IpAddrInfo::Start(const String& host_, int port_) +{ + Clear(); + port = AsString(port_); + host = host_; + Start(); +} + +bool IpAddrInfo::InProgress() +{ + if(!entry) { + Start(); + return true; + } + EnterPool(); + int s = entry->status; + LeavePool(); + return s == WORKING; +} + +addrinfo *IpAddrInfo::GetResult() +{ + EnterPool(); + addrinfo *ai = entry ? entry->addr : NULL; + LeavePool(); + return ai; +} + +void IpAddrInfo::Clear() +{ + EnterPool(); + if(entry) { + if(entry->status == RESOLVED && entry->addr) + freeaddrinfo(entry->addr); + if(entry->status == WORKING) + entry->status = CANCELED; + else + entry->status = EMPTY; + entry = NULL; + } + LeavePool(); +} + +IpAddrInfo::IpAddrInfo() +{ + TcpSocket::Init(); + entry = NULL; +} + +#ifdef PLATFORM_POSIX + +#define SOCKERR(x) x + +const char *TcpSocketErrorDesc(int code) +{ + return strerror(code); +} + +int TcpSocket::GetErrorCode() +{ + return errno; +} + +#else + +#define SOCKERR(x) WSA##x + +const char *TcpSocketErrorDesc(int code) +{ + static Tuple2 err[] = { + { WSAEINTR, "Interrupted function call." }, + { WSAEACCES, "Permission denied." }, + { WSAEFAULT, "Bad address." }, + { WSAEINVAL, "Invalid argument." }, + { WSAEMFILE, "Too many open files." }, + { WSAEWOULDBLOCK, "Resource temporarily unavailable." }, + { WSAEINPROGRESS, "Operation now in progress." }, + { WSAEALREADY, "Operation already in progress." }, + { WSAENOTSOCK, "TcpSocket operation on nonsocket." }, + { WSAEDESTADDRREQ, "Destination address required." }, + { WSAEMSGSIZE, "Message too long." }, + { WSAEPROTOTYPE, "Protocol wrong type for socket." }, + { WSAENOPROTOOPT, "Bad protocol option." }, + { WSAEPROTONOSUPPORT, "Protocol not supported." }, + { WSAESOCKTNOSUPPORT, "TcpSocket type not supported." }, + { WSAEOPNOTSUPP, "Operation not supported." }, + { WSAEPFNOSUPPORT, "Protocol family not supported." }, + { WSAEAFNOSUPPORT, "Address family not supported by protocol family." }, + { WSAEADDRINUSE, "Address already in use." }, + { WSAEADDRNOTAVAIL, "Cannot assign requested address." }, + { WSAENETDOWN, "Network is down." }, + { WSAENETUNREACH, "Network is unreachable." }, + { WSAENETRESET, "Network dropped connection on reset." }, + { WSAECONNABORTED, "Software caused connection abort." }, + { WSAECONNRESET, "Connection reset by peer." }, + { WSAENOBUFS, "No buffer space available." }, + { WSAEISCONN, "TcpSocket is already connected." }, + { WSAENOTCONN, "TcpSocket is not connected." }, + { WSAESHUTDOWN, "Cannot send after socket shutdown." }, + { WSAETIMEDOUT, "Connection timed out." }, + { WSAECONNREFUSED, "Connection refused." }, + { WSAEHOSTDOWN, "Host is down." }, + { WSAEHOSTUNREACH, "No route to host." }, + { WSAEPROCLIM, "Too many processes." }, + { WSASYSNOTREADY, "Network subsystem is unavailable." }, + { WSAVERNOTSUPPORTED, "Winsock.dll version out of range." }, + { WSANOTINITIALISED, "Successful WSAStartup not yet performed." }, + { WSAEDISCON, "Graceful shutdown in progress." }, + { WSATYPE_NOT_FOUND, "Class type not found." }, + { WSAHOST_NOT_FOUND, "Host not found." }, + { WSATRY_AGAIN, "Nonauthoritative host not found." }, + { WSANO_RECOVERY, "This is a nonrecoverable error." }, + { WSANO_DATA, "Valid name, no data record of requested type." }, + { WSASYSCALLFAILURE, "System call failure." }, + }; + const Tuple2 *x = FindTuple(err, __countof(err), code); + return x ? x->b : "Unknown error code."; +} + +int TcpSocket::GetErrorCode() +{ + return WSAGetLastError(); +} + +#endif + +void TcpSocketInit() +{ +#if defined(PLATFORM_WIN32) + ONCELOCK { + WSADATA wsadata; + WSAStartup(MAKEWORD(2, 2), &wsadata); + } +#endif +} + +void TcpSocket::Init() +{ + TcpSocketInit(); +} + +void TcpSocket::Reset() +{ + is_eof = false; + socket = INVALID_SOCKET; + ipv6 = false; + ptr = end = buffer; + is_error = false; + is_abort = false; + mode = NONE; + ssl.Clear(); + sslinfo.Clear(); +} + +TcpSocket::TcpSocket() +{ + ClearError(); + Reset(); + timeout = Null; + waitstep = 20; + asn1 = false; +} + +bool TcpSocket::Open(int family, int type, int protocol) +{ + Init(); + Close(); + ClearError(); + if((socket = ::socket(family, type, protocol)) == INVALID_SOCKET) + return false; + LLOG("TcpSocket::Data::Open() -> " << (int)socket); +#ifdef PLATFORM_WIN32 + u_long arg = 1; + if(ioctlsocket(socket, FIONBIO, &arg)) + SetSockError("ioctlsocket(FIO[N]BIO)"); +#else + if(fcntl(socket, F_SETFL, (fcntl(socket, F_GETFL, 0) | O_NONBLOCK))) + SetSockError("fcntl(O_[NON]BLOCK)"); +#endif + return true; +} + +bool TcpSocket::Listen(int port, int listen_count, bool ipv6_, bool reuse) +{ + Close(); + Init(); + Reset(); + + ipv6 = ipv6_; + if(!Open(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) + return false; + sockaddr_in sin; +#ifdef PLATFORM_WIN32 + SOCKADDR_IN6 sin6; + if(ipv6 && IsWinVista()) +#else + sockaddr_in6 sin6; + if(ipv6) +#endif + { + Zero(sin6); + sin.sin_family = AF_INET6; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = htonl(INADDR_ANY); + } + else { + Zero(sin); + sin.sin_family = AF_INET; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = htonl(INADDR_ANY); + } + if(reuse) { + int optval = 1; + setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (const char *)&optval, sizeof(optval)); + } + if(bind(socket, ipv6 ? (const sockaddr *)&sin6 : (const sockaddr *)&sin, + ipv6 ? sizeof(sin6) : sizeof(sin))) { + SetSockError(Format("bind(port=%d)", port)); + return false; + } + if(listen(socket, listen_count)) { + SetSockError(Format("listen(port=%d, count=%d)", port, listen_count)); + return false; + } + return true; +} + +bool TcpSocket::Accept(TcpSocket& ls) +{ + Close(); + Init(); + Reset(); + ASSERT(ls.IsOpen()); + if(timeout) { + int h = ls.GetTimeout(); + bool b = ls.Timeout(timeout).Wait(WAIT_READ, GetEndTime()); + ls.Timeout(h); + if(!b) + return false; + } + if(!Open(ls.ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) + return false; + socket = accept(ls.GetSOCKET(), NULL, NULL); + if(socket == INVALID_SOCKET) { + SetSockError("accept"); + return false; + } + mode = ACCEPT; + return true; +} + +String TcpSocket::GetPeerAddr() const +{ + if(!IsOpen()) + return Null; + sockaddr_in addr; + socklen_t l = sizeof(addr); + if(getpeername(socket, (sockaddr *)&addr, &l) != 0) + return Null; + if(l > sizeof(addr)) + return Null; +#ifdef PLATFORM_WIN32 + return inet_ntoa(addr.sin_addr); +#else + char h[200]; + return inet_ntop(AF_INET, &addr.sin_addr, h, 200); +#endif +} + +void TcpSocket::NoDelay() +{ + ASSERT(IsOpen()); + int __true = 1; + LLOG("NoDelay(" << (int)socket << ")"); + if(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (const char *)&__true, sizeof(__true))) + SetSockError("setsockopt(TCP_NODELAY)"); +} + +void TcpSocket::Linger(int msecs) +{ + ASSERT(IsOpen()); + linger ls; + ls.l_onoff = !IsNull(msecs) ? 1 : 0; + ls.l_linger = !IsNull(msecs) ? (msecs + 999) / 1000 : 0; + if(setsockopt(socket, SOL_SOCKET, SO_LINGER, + reinterpret_cast(&ls), sizeof(ls))) + SetSockError("setsockopt(SO_LINGER)"); +} + +void TcpSocket::Attach(SOCKET s) +{ + Close(); + socket = s; +} + +bool TcpSocket::RawConnect(addrinfo *rp) +{ + if(!rp) { + SetSockError("connect", -1, "not found"); + return false; + } + for(;;) { + if(rp && Open(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) { + if(connect(socket, rp->ai_addr, rp->ai_addrlen) == 0 || + GetErrorCode() == SOCKERR(EINPROGRESS) || GetErrorCode() == SOCKERR(EWOULDBLOCK)) + break; + Close(); + } + rp = rp->ai_next; + if(!rp) { + SetSockError("connect", -1, "failed"); + return false; + } + } + mode = CONNECT; + return true; +} + + +bool TcpSocket::Connect(IpAddrInfo& info) +{ + LLOG("TCP Connect addrinfo"); + Init(); + Reset(); + addrinfo *result = info.GetResult(); + return RawConnect(result); +} + +bool TcpSocket::Connect(const char *host, int port) +{ + LLOG("TCP Connect(" << host << ':' << port << ')'); + + Init(); + Reset(); + IpAddrInfo info; + if(!info.Execute(host, port)) { + SetSockError(Format("getaddrinfo(%s) failed", host)); + return false; + } + return Connect(info); +} + +void TcpSocket::RawClose() +{ + LLOG("TCP close " << (int)socket); + if(socket != INVALID_SOCKET) { + int res; +#if defined(PLATFORM_WIN32) + res = closesocket(socket); +#elif defined(PLATFORM_POSIX) + res = close(socket); +#else + #error Unsupported platform +#endif + if(res && !IsError()) + SetSockError("close"); + socket = INVALID_SOCKET; + } +} + +void TcpSocket::Close() +{ + if(ssl) + ssl->Close(); + else + RawClose(); + ssl.Clear(); +} + +bool TcpSocket::WouldBlock() +{ + int c = GetErrorCode(); +#ifdef PLATFORM_POSIX + return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(EAGAIN); +#endif +#ifdef PLATFORM_WIN32 + return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(ENOTCONN); +#endif +} + +int TcpSocket::RawSend(const void *buf, int amount) +{ + int res = send(socket, (const char *)buf, amount, 0); + if(res < 0 && WouldBlock()) + res = 0; + else + if(res == 0 || res < 0) + SetSockError("send"); + return res; +} + +int TcpSocket::Send(const void *buf, int amount) +{ + if(SSLHandshake()) + return 0; + return ssl ? ssl->Send(buf, amount) : RawSend(buf, amount); +} + +void TcpSocket::Shutdown() +{ + ASSERT(IsOpen()); + if(shutdown(socket, SD_SEND)) + SetSockError("shutdown(SD_SEND)"); +} + +String TcpSocket::GetHostName() +{ + Init(); + char buffer[256]; + gethostname(buffer, __countof(buffer)); + return buffer; +} + +bool TcpSocket::RawWait(dword flags, int end_time) +{ + LLOG("Wait(" << msecs() << " - " << end_time << ", " << flags << ")"); + if((flags & WAIT_READ) && ptr != end) + return true; + if(socket == INVALID_SOCKET) + return false; + for(;;) { + if(IsError() || IsAbort()) + return false; + int to = end_time - msecs(); + if(WhenWait) + to = waitstep; + timeval *tvalp = NULL; + timeval tval; + if(!IsNull(timeout) || WhenWait) { + to = max(to, 0); + tval.tv_sec = to / 1000; + tval.tv_usec = 1000 * (to % 1000); + tvalp = &tval; + } + fd_set fdset[1]; + FD_ZERO(fdset); + FD_SET(socket, fdset); + int avail = select((int)socket + 1, + flags & WAIT_READ ? fdset : NULL, + flags & WAIT_WRITE ? fdset : NULL, + flags & WAIT_EXCEPTION ? fdset : NULL, tvalp); + LLOG("Wait select avail: " << avail); + if(avail < 0) { + SetSockError("wait"); + return false; + } + if(avail > 0) + return true; + if(to <= 0 && timeout) + return false; + WhenWait(); + if(timeout == 0) + return false; + } +} + +bool TcpSocket::Wait(dword flags, int end_time) +{ + return ssl ? ssl->Wait(flags, end_time) : RawWait(flags, end_time); +} + +int TcpSocket::GetEndTime() const +{ + return IsNull(timeout) ? INT_MAX : msecs() + timeout; +} + +bool TcpSocket::Wait(dword flags) +{ + return Wait(flags, GetEndTime()); +} + +int TcpSocket::Put(const char *s, int length) +{ + LLOG("Put " << socket << ": " << length); + ASSERT(IsOpen()); + if(length < 0 && s) + length = (int)strlen(s); + if(!s || length <= 0 || IsError() || IsAbort()) + return 0; + done = 0; + bool peek = false; + int end_time = GetEndTime(); + while(done < length) { + if(peek && !Wait(WAIT_WRITE, end_time)) + return done; + peek = false; + int count = Send(s + done, length - done); + if(IsError() || timeout == 0 && count == 0 && peek) + return done; + if(count > 0) + done += count; + else + peek = true; + } + LLOG("//Put() -> " << done); + return done; +} + +bool TcpSocket::PutAll(const char *s, int len) +{ + if(Put(s, len) != len) { + if(!IsError()) + SetSockError("GePutAll", -1, "timeout"); + return false; + } + return true; +} + +bool TcpSocket::PutAll(const String& s) +{ + if(Put(s) != s.GetCount()) { + if(!IsError()) + SetSockError("GePutAll", -1, "timeout"); + return false; + } + return true; +} + +int TcpSocket::RawRecv(void *buf, int amount) +{ + int res = recv(socket, (char *)buf, amount, 0); + if(res == 0) + is_eof = true; + else + if(res < 0 && WouldBlock()) + res = 0; + else + if(res < 0) + SetSockError("recv"); + LLOG("recv(" << socket << "): " << res << " bytes: " + << AsCString((char *)buf, (char *)buf + min(res, 16)) + << (res ? "" : IsEof() ? ", EOF" : ", WOULDBLOCK")); + return res; +} + +int TcpSocket::Recv(void *buffer, int maxlen) +{ + if(SSLHandshake()) + return 0; + return ssl ? ssl->Recv(buffer, maxlen) : RawRecv(buffer, maxlen); +} + +void TcpSocket::ReadBuffer(int end_time) +{ + ptr = end = buffer; + if(Wait(WAIT_READ, end_time)) + end = buffer + Recv(buffer, BUFFERSIZE); +} + +int TcpSocket::Get_() +{ + if(!IsOpen() || IsError() || IsEof() || IsAbort()) + return -1; + ReadBuffer(GetEndTime()); + return ptr < end ? *ptr++ : -1; +} + +int TcpSocket::Peek_(int end_time) +{ + if(!IsOpen() || IsError() || IsEof() || IsAbort()) + return -1; + ReadBuffer(end_time); + return ptr < end ? *ptr : -1; +} + +int TcpSocket::Peek_() +{ + return Peek_(GetEndTime()); +} + +int TcpSocket::Get(void *buffer, int count) +{ + LLOG("Get " << count); + + if(!IsOpen() || IsError() || IsEof() || IsAbort()) + return 0; + + String out; + int l = end - ptr; + done = 0; + if(l > 0) + if(l < count) { + memcpy(buffer, ptr, l); + done += l; + ptr = end; + } + else { + memcpy(buffer, ptr, count); + ptr += count; + return count; + } + int end_time = GetEndTime(); + while(done < count && !IsError() && !IsEof()) { + if(!Wait(WAIT_READ, end_time)) + break; + int part = Recv((char *)buffer + done, count - done); + if(part > 0) + done += part; + if(timeout == 0) + break; + } + return done; +} + +String TcpSocket::Get(int count) +{ + if(count == 0) + return Null; + StringBuffer out(count); + int done = Get(out, count); + if(!done && IsEof()) + return String::GetVoid(); + out.SetLength(done); + return out; +} + +bool TcpSocket::GetAll(void *buffer, int len) +{ + if(Get(buffer, len) == len) + return true; + if(!IsError()) + SetSockError("GetAll", -1, "timeout"); + return false; +} + +String TcpSocket::GetAll(int len) +{ + String s = Get(len); + if(s.GetCount() != len) { + if(!IsError()) + SetSockError("GetAll", -1, "timeout"); + return String::GetVoid(); + } + return s; +} + +String TcpSocket::GetLine(int maxlen) +{ + String ln; + int end_time = GetEndTime(); + for(;;) { + int c = Peek(end_time); + if(c < 0) { + if(!IsError()) + SetSockError("GetLine", -1, "timeout"); + return String::GetVoid(); + } + Get(); + if(c == '\n') + return ln; + if(ln.GetCount() >= maxlen) { + if(!IsError()) + SetSockError("GetLine", -1, "maximal length exceeded"); + return String::GetVoid(); + } + if(c != '\r') + ln.Cat(c); + } +} + +void TcpSocket::SetSockError(const char *context, int code, const char *errdesc) +{ + errorcode = code; + errordesc.Clear(); + if(socket != INVALID_SOCKET) + errordesc << "socket(" << (int)socket << ") / "; + errordesc << context << ": " << errdesc; + is_error = true; + LLOG("TCP ERROR " << errordesc); +} + +void TcpSocket::SetSockError(const char *context, const char *errdesc) +{ + SetSockError(context, GetErrorCode(), errdesc); +} + +void TcpSocket::SetSockError(const char *context) +{ + SetSockError(context, TcpSocketErrorDesc(GetErrorCode())); +} + +TcpSocket::SSL *(*TcpSocket::CreateSSL)(TcpSocket& socket); + +bool TcpSocket::StartSSL() +{ + ASSERT(IsOpen()); + if(!CreateSSL) { + SetSockError("StartSSL", -1, "Missing SSL support (Core/SSL)"); + return false; + } + if(!IsOpen()) { + SetSockError("StartSSL", -1, "Socket is not open"); + return false; + } + if(mode != CONNECT && mode != ACCEPT) { + SetSockError("StartSSL", -1, "Socket is not connected"); + return false; + } + ssl = (*CreateSSL)(*this); + if(!ssl->Start()) { + ssl.Clear(); + return false; + } + SSLHandshake(); + return true; +} + +bool TcpSocket::SSLHandshake() +{ + if(ssl && (mode == CONNECT || mode == ACCEPT)) { + dword w = ssl->Handshake(); + if(w) { + Wait(w); + return ssl->Handshake(); + } + } + return false; +} + +void TcpSocket::SSLCertificate(const String& cert_, const String& pkey_, bool asn1_) +{ + cert = cert_; + pkey = pkey_; + asn1 = asn1_; +} + +int SocketWaitEvent::Wait(int timeout) +{ + FD_ZERO(read); + FD_ZERO(write); + FD_ZERO(exception); + int maxindex = -1; + for(int i = 0; i < socket.GetCount(); i++) { + const Tuple2& s = socket[i]; + if(s.a >= 0) { + const Tuple2& s = socket[i]; + if(s.b & WAIT_READ) + FD_SET(s.a, read); + if(s.b & WAIT_WRITE) + FD_SET(s.a, write); + if(s.b & WAIT_EXCEPTION) + FD_SET(s.a, exception); + maxindex = max(s.a, maxindex); + } + } + timeval *tvalp = NULL; + timeval tval; + if(!IsNull(timeout)) { + tval.tv_sec = timeout / 1000; + tval.tv_usec = 1000 * (timeout % 1000); + tvalp = &tval; + } + return select(maxindex + 1, read, write, exception, tvalp); +} + +dword SocketWaitEvent::Get(int i) const +{ + int s = socket[i].a; + if(s < 0) + return 0; + dword events = 0; + if(FD_ISSET(s, read)) + events |= WAIT_READ; + if(FD_ISSET(s, write)) + events |= WAIT_WRITE; + if(FD_ISSET(s, exception)) + events |= WAIT_EXCEPTION; + return events; +} + +SocketWaitEvent::SocketWaitEvent() +{ + FD_ZERO(read); + FD_ZERO(write); + FD_ZERO(exception); +} + +END_UPP_NAMESPACE diff --git a/uppsrc/Core/Util.cpp b/uppsrc/Core/Util.cpp index 46a4f13b3..ee26bde2c 100644 --- a/uppsrc/Core/Util.cpp +++ b/uppsrc/Core/Util.cpp @@ -128,7 +128,7 @@ dword GetTickCount() { #endif -int msecs(int from) { return int((GetTickCount() - (dword)from) & 0x7fffffff); } +int msecs(int from) { return GetTickCount() - (dword)from; } void TimeStop::Reset() { diff --git a/uppsrc/Web/SSL/httpscli.cpp b/uppsrc/Web/SSL/httpscli.cpp index 8c7e9a804..e63aa8bf9 100644 --- a/uppsrc/Web/SSL/httpscli.cpp +++ b/uppsrc/Web/SSL/httpscli.cpp @@ -1,128 +1,128 @@ -#ifndef flagNOSSL - -#include "WebSSL.h" - -NAMESPACE_UPP - -extern bool HttpClient_Trace__; - -#ifdef _DEBUG -#define LLOG(x) if(HttpClient_Trace__) RLOG(x); else; -#else -#define LLOG(x) -#endif - -HttpsClient::HttpsClient() -{ - secure = true; -} - -bool HttpsClient::ProxyConnect() -{ - if(use_proxy) { - int start_time = msecs(); - int end_time = msecs() + timeout_msecs; - while(!socket.PeekWrite(1000)) { - int time = msecs(); - if(time >= end_time) { - error = NFormat(t_("%s:%d: connecting to host timed out"), socket_host, socket_port); - Close(); - return false; - } - } - String host_port = host; - if(port) - host_port << ':' << port; - else - host_port << ":443"; - String request; - request << "CONNECT " << host_port << " HTTP/1.1\r\n" - << "Host: " << host_port << "\r\n"; - if(!IsNull(proxy_username)) - request << "Proxy-Authorization: Basic " - << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; - request << "\r\n"; - LLOG(request); - int written = 0; - while(msecs() < end_time) { - int nwrite = socket.WriteWait(request.GetIter(written), min(request.GetLength() - written, 1000), 1000); - if(socket.IsError()) { - error = Socket::GetErrorText(); - Close(); - return false; - } - if((written += nwrite) >= request.GetLength()) - break; - } - if(written < request.GetLength()) { - error = NFormat(t_("%s:%d: timed out sending request to server"), host, port); - Close(); - return false; - } - String line = ReadUntilProgress('\n', start_time, end_time, false); - LLOG("P< " << line); - if(socket.IsError()) { - error = Socket::GetErrorText(); - Close(); - return false; - } - if(!line.StartsWith("HTTP") || line.Find(" 2") < 0) { - error = "Invalid proxy reply: " + line; - Close(); - return false; - } - while(line.GetCount()) { - line = ReadUntilProgress('\n', start_time, end_time, false); - if(*line.Last() == '\r') - line.Trim(line.GetCount() - 1); - LLOG("P< " << line << " len " << line.GetCount()); - if(socket.IsError()) { - error = Socket::GetErrorText(); - Close(); - return false; - } - } - use_proxy = false; - while(!socket.PeekWrite(1000)) { - int time = msecs(); - if(time >= end_time) { - error = NFormat(t_("%s:%d: connecting to host timed out"), socket_host, socket_port); - Close(); - return false; - } - } - } - return true; -} - -bool HttpsClient::IsSecure() -{ - return secure; -} - -bool HttpsClient::CreateClientSocket() -{ - if(!secure) - return HttpClient::CreateClientSocket(); - if(!ssl_context) { - ssl_context = new SSLContext; - if(!ssl_context->Create(const_cast(SSLv3_client_method()))) { - error = t_("Error creating SSL context."); - return false; - } - } - if(!SSLClientSocketUnsecured(socket, *ssl_context, socket_host, - socket_port ? socket_port : DEFAULT_HTTPS_PORT, true, NULL, 0, false)) { - error = Socket::GetErrorText(); - return false; - } - socket.Linger(0); - if(!ProxyConnect()) - return false; - SSLSecureSocket(socket); - return true; -} - -END_UPP_NAMESPACE - -#endif +#ifndef flagNOSSL + +#include "WebSSL.h" + +NAMESPACE_UPP + +extern bool HttpClient_Trace__; + +#ifdef _DEBUG +#define LLOG(x) if(HttpClient_Trace__) RLOG(x); else; +#else +#define LLOG(x) +#endif + +HttpsClient::HttpsClient() +{ + secure = true; +} + +bool HttpsClient::ProxyConnect() +{ + if(use_proxy) { + int start_time = msecs(); + int end_time = msecs() + timeout_msecs; + while(!socket.PeekWrite(1000)) { + int time = msecs(); + if(time >= end_time) { + error = NFormat(t_("%s:%d: connecting to host timed out"), socket_host, socket_port); + Close(); + return false; + } + } + String host_port = host; + if(port) + host_port << ':' << port; + else + host_port << ":443"; + String request; + request << "CONNECT " << host_port << " HTTP/1.1\r\n" + << "Host: " << host_port << "\r\n"; + if(!IsNull(proxy_username)) + request << "Proxy-Authorization: Basic " + << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; + request << "\r\n"; + LLOG(request); + int written = 0; + while(msecs() - end_time < 0) { + int nwrite = socket.WriteWait(request.GetIter(written), min(request.GetLength() - written, 1000), 1000); + if(socket.IsError()) { + error = Socket::GetErrorText(); + Close(); + return false; + } + if((written += nwrite) >= request.GetLength()) + break; + } + if(written < request.GetLength()) { + error = NFormat(t_("%s:%d: timed out sending request to server"), host, port); + Close(); + return false; + } + String line = ReadUntilProgress('\n', start_time, end_time, false); + LLOG("P< " << line); + if(socket.IsError()) { + error = Socket::GetErrorText(); + Close(); + return false; + } + if(!line.StartsWith("HTTP") || line.Find(" 2") < 0) { + error = "Invalid proxy reply: " + line; + Close(); + return false; + } + while(line.GetCount()) { + line = ReadUntilProgress('\n', start_time, end_time, false); + if(*line.Last() == '\r') + line.Trim(line.GetCount() - 1); + LLOG("P< " << line << " len " << line.GetCount()); + if(socket.IsError()) { + error = Socket::GetErrorText(); + Close(); + return false; + } + } + use_proxy = false; + while(!socket.PeekWrite(1000)) { + int time = msecs(); + if(time >= end_time) { + error = NFormat(t_("%s:%d: connecting to host timed out"), socket_host, socket_port); + Close(); + return false; + } + } + } + return true; +} + +bool HttpsClient::IsSecure() +{ + return secure; +} + +bool HttpsClient::CreateClientSocket() +{ + if(!secure) + return HttpClient::CreateClientSocket(); + if(!ssl_context) { + ssl_context = new SSLContext; + if(!ssl_context->Create(const_cast(SSLv3_client_method()))) { + error = t_("Error creating SSL context."); + return false; + } + } + if(!SSLClientSocketUnsecured(socket, *ssl_context, socket_host, + socket_port ? socket_port : DEFAULT_HTTPS_PORT, true, NULL, 0, false)) { + error = Socket::GetErrorText(); + return false; + } + socket.Linger(0); + if(!ProxyConnect()) + return false; + SSLSecureSocket(socket); + return true; +} + +END_UPP_NAMESPACE + +#endif diff --git a/uppsrc/Web/httpcli.cpp b/uppsrc/Web/httpcli.cpp index 90932fd17..0535849e7 100644 --- a/uppsrc/Web/httpcli.cpp +++ b/uppsrc/Web/httpcli.cpp @@ -292,7 +292,7 @@ String HttpClient::Execute(Gate2 progress) LLOG("host = " << host << ", port = " << port); LLOG("request: " << request); int written = 0; - while(msecs() < end_time) { + while(msecs() - end_time < 0) { int nwrite = socket.WriteWait(request.GetIter(written), min(request.GetLength() - written, 1000), 1000); if(socket.IsError()) { error = Socket::GetErrorText(); diff --git a/uppsrc/Web/httpcli_old.cpp b/uppsrc/Web/httpcli_old.cpp index 159458008..0adacb5ae 100644 --- a/uppsrc/Web/httpcli_old.cpp +++ b/uppsrc/Web/httpcli_old.cpp @@ -177,7 +177,7 @@ String HttpClient::Execute(Gate2 progress) LLOG("host = " << host << ", port = " << port); LLOG("request: " << request); int written = 0; - while(msecs() < end_time) { + while(msecs() - end_time < 0) { int nwrite = socket.WriteWait(request.GetIter(written), request.GetLength() - written, 1000); if(socket.IsError()) { error = Socket::GetErrorText();