From 0fac8a45ff78e4ea84fc486b611cf5528aed3e84 Mon Sep 17 00:00:00 2001 From: cxl Date: Mon, 9 Apr 2012 17:55:28 +0000 Subject: [PATCH] .Core/SSL git-svn-id: svn://ultimatepp.org/upp/trunk@4766 f0d560ea-af0d-0410-9eb7-867de7ffcac7 --- uppsrc/Core/Http.cpp | 1186 ++++++++++---------- uppsrc/Core/SSL/SSL.upp | 4 + uppsrc/Core/SSL/Socket.cpp | 749 +++++++------ uppsrc/Core/Socket.cpp | 1555 +++++++++++++------------- uppsrc/Core/Stream.cpp | 4 +- uppsrc/Core/Stream.h | 4 +- uppsrc/Core/Web.h | 739 ++++++------ uppsrc/Core/src.tpp/Stream$en-us.tpp | 18 +- 8 files changed, 2169 insertions(+), 2090 deletions(-) diff --git a/uppsrc/Core/Http.cpp b/uppsrc/Core/Http.cpp index 35040bbbf..a02f1b16f 100644 --- a/uppsrc/Core/Http.cpp +++ b/uppsrc/Core/Http.cpp @@ -1,591 +1,595 @@ -#include "Core.h" - -NAMESPACE_UPP - -bool HttpRequest_Trace__; - -#define LLOG(x) do { if(HttpRequest_Trace__) RLOG(x); } while(0) - -#ifdef _DEBUG -_DBG_ -// #define ENDZIP -#endif - -void HttpRequest::Trace(bool b) -{ - HttpRequest_Trace__ = b; -} - -void HttpRequest::Init() -{ - port = 0; - proxy_port = 0; - max_header_size = 1000000; - max_content_size = 10000000; - max_redirects = 5; - max_retries = 3; - force_digest = false; - std_headers = true; - hasurlvar = false; - method = METHOD_GET; - phase = START; - redirect_count = 0; - retry_count = 0; - gzip = false; - WhenContent = callback(this, &HttpRequest::ContentOut); - chunk = 4096; - timeout = 120000; -} - -HttpRequest::HttpRequest() -{ - Init(); -} - -HttpRequest::HttpRequest(const char *url) -{ - Init(); - Url(url); -} - -HttpRequest& HttpRequest::Url(const char *u) -{ - const char *t = u; - while(*t && *t != '?') - if(*t++ == '/' && *t == '/') { - u = ++t; - break; - } - t = u; - while(*u && *u != ':' && *u != '/' && *u != '?') - u++; - if(*u == '?' && u[1]) - hasurlvar = true; - host = String(t, u); - port = 0; - if(*u == ':') - port = ScanInt(u + 1, &u); - path = u; - int q = path.Find('#'); - if(q >= 0) - path.Trim(q); - return *this; -} - -HttpRequest& HttpRequest::Proxy(const char *p) -{ - const char *t = p; - while(*p && *p != ':') - p++; - proxy_host = String(t, p); - proxy_port = 80; - if(*p++ == ':' && IsDigit(*p)) - proxy_port = ScanInt(p); - return *this; -} - -HttpRequest& HttpRequest::Post(const char *id, const String& data) -{ - POST(); - if(postdata.GetCount()) - postdata << '&'; - postdata << id << '=' << UrlEncode(data); - return *this; -} - -HttpRequest& HttpRequest::UrlVar(const char *id, const String& data) -{ - int c = *path.Last(); - if(hasurlvar && c != '&') - path << '&'; - if(!hasurlvar && c != '?') - path << '?'; - path << id << '=' << UrlEncode(data); - hasurlvar = true; - return *this; -} - -String HttpRequest::CalculateDigest(const String& authenticate) const -{ - const char *p = authenticate; - String realm, qop, nonce, opaque; - while(*p) { - if(!IsAlNum(*p)) { - p++; - continue; - } - else { - const char *b = p; - while(IsAlNum(*p)) - p++; - String var = ToLower(String(b, p)); - String value; - while(*p && (byte)*p <= ' ') - p++; - if(*p == '=') { - p++; - while(*p && (byte)*p <= ' ') - p++; - if(*p == '\"') { - p++; - while(*p && *p != '\"') - if(*p != '\\' || *++p) - value.Cat(*p++); - if(*p == '\"') - p++; - } - else { - b = p; - while(*p && *p != ',' && (byte)*p > ' ') - p++; - value = String(b, p); - } - } - if(var == "realm") - realm = value; - else if(var == "qop") - qop = value; - else if(var == "nonce") - nonce = value; - else if(var == "opaque") - opaque = value; - } - } - String hv1, hv2; - hv1 << username << ':' << realm << ':' << password; - String ha1 = MD5String(hv1); - hv2 << (method == METHOD_GET ? "GET" : method == METHOD_PUT ? "PUT" : method == METHOD_POST ? "POST" : "READ") - << ':' << path; - String ha2 = MD5String(hv2); - int nc = 1; - String cnonce = FormatIntHex(Random(), 8); - String hv; - hv << ha1 - << ':' << nonce - << ':' << FormatIntHex(nc, 8) - << ':' << cnonce - << ':' << qop << ':' << ha2; - String ha = MD5String(hv); - String auth; - auth << "username=" << AsCString(username) - << ", realm=" << AsCString(realm) - << ", nonce=" << AsCString(nonce) - << ", uri=" << AsCString(path) - << ", qop=" << AsCString(qop) - << ", nc=" << AsCString(FormatIntHex(nc, 8)) - << ", cnonce=" << cnonce - << ", response=" << AsCString(ha); - if(!IsNull(opaque)) - auth << ", opaque=" << AsCString(opaque); - return auth; -} - -HttpRequest& HttpRequest::Header(const char *id, const String& data) -{ - request_headers << id << ": " << data << "\r\n"; - return *this; -} - -void HttpRequest::HttpError(const char *s) -{ - if(IsError()) - return; - error = NFormat(t_("%s:%d: ") + String(s), host, port); - LLOG("HTTP ERROR: " << error); - Close(); -} - -void HttpRequest::StartPhase(int s) -{ - LLOG("Starting status " << s << ' ' << host); - phase = s; - data.Clear(); -} - -bool HttpRequest::Do() -{ - int c1, c2; - switch(phase) { - case START: - retry_count = 0; - redirect_count = 0; - start_time = msecs(); - Start(); - break; - case DNS: - Dns(); - break; - case REQUEST: - if(SendingData()) - break; - StartPhase(HEADER); - break; - case HEADER: - if(ReadingHeader()) - break; - StartBody(); - break; - case BODY: - if(ReadingBody()) - break; - Finish(); - break; - case CHUNK_HEADER: - ReadingChunkHeader(); - break; - case CHUNK_BODY: - if(ReadingBody()) - break; - c1 = Get(); - c2 = Get(); - if(c1 != '\r' || c2 != '\n') - HttpError("missing ending CRLF in chunked transfer"); - StartPhase(CHUNK_HEADER); - break; - case TRAILER: - if(ReadingHeader()) - break; - header.Parse(data); - Finish(); - break; - case FINISHED: - case FAILED: - return false; - default: - NEVER(); - } - - if(phase != FAILED) - if(IsSocketError() || IsError()) - phase = FAILED; - else - if(msecs() - start_time >= timeout) { - HttpError("connection timed out"); - phase = FAILED; - } - else - if(IsAbort()) { - HttpError("connection was aborted"); - phase = FAILED; - } - - if(phase == FAILED) { - if(retry_count++ < max_retries) { - LLOG("HTTP retry on error " << GetErrorDesc()); - StartRequest(); - } - } - return phase != FINISHED && phase != FAILED; -} - -void HttpRequest::Finish() -{ - if(gzip) { - #ifdef ENDZIP - body = GZDecompress(body); - if(body.IsVoid()) { - HttpError("gzip decompress at finish error"); - phase = FAILED; - return; - } - #else - z.End(); - if(z.IsError()) { - HttpError("gzip format error (finish)"); - phase = FAILED; - return; - } - #endif - } - Close(); - if(status_code == 401 && !IsNull(username)) { - String authenticate = header["www-authenticate"]; - if(authenticate.GetCount() && redirect_count++ < max_redirects) { - LLOG("HTTP auth digest"); - Digest(CalculateDigest(authenticate)); - StartRequest(); - return; - } - } - if(status_code >= 300 && status_code < 400) { - String url = GetRedirectUrl(); - if(url.GetCount() && redirect_count++ < max_redirects) { - LLOG("HTTP redirect " << url); - Url(url); - StartRequest(); - retry_count = 0; - return; - } - } - phase = FINISHED; - -// if(retry_count < 2) -// HttpError("Checking retry"); -} - -void HttpRequest::ReadingChunkHeader() -{ - for(;;) { - int c = Get(); - if(c < 0) - break; - else - if(c == '\n') { - int n = ScanInt(~data, NULL, 16); - LLOG("HTTP Chunk header: 0x" << data << " = " << n); - if(IsNull(n)) { - HttpError("invalid chunk header"); - break; - } - if(n == 0) { - StartPhase(TRAILER); - break; - } - count += n; - StartPhase(CHUNK_BODY); - break; - } - if(c != '\r') - data.Cat(c); - } -} - -String HttpRequest::GetRedirectUrl() -{ - String redirect_url = TrimLeft(header["location"]); - int q = redirect_url.Find('?'); - int p = path.Find('?'); - if(p >= 0 && q < 0) - redirect_url.Cat(path.Mid(p)); - return redirect_url; -} - -int HttpRequest::GetContentLength() -{ - return Nvl(ScanInt(header["content-length"]), -1); -} - -void HttpRequest::StartBody() -{ - LLOG("HTTP Header received: "); - LLOG(data); - header.Clear(); - if(!header.Parse(data)) { - HttpError("invalid HTTP header"); - return; - } - - if(!header.Response(protocol, status_code, reason_phrase)) { - HttpError("invalid HTTP response"); - return; - } - - LLOG("HTTP status code: " << status_code); - - count = GetContentLength(); - - if(count > 0) - body.Reserve(count); - - if(method == METHOD_HEAD) - phase = FINISHED; - else - if(header["transfer-encoding"] == "chunked") { - count = 0; - StartPhase(CHUNK_HEADER); - } - else - StartPhase(BODY); - body.Clear(); - bodylen = 0; - gzip = GetHeader("content-encoding") == "gzip"; - if(gzip) { - gzip = true; - z.WhenOut = callback(this, &HttpRequest::Out); - z.ChunkSize(chunk).GZip().Decompress(); - } -} - -void HttpRequest::ContentOut(const void *ptr, dword size) -{ - body.Cat((const char *)ptr, size); -} - -void HttpRequest::Out(const void *ptr, dword size) -{ - LLOG("HTTP Out " << size); - if(z.IsError()) - HttpError("gzip format error"); - int64 l = bodylen + size; - if(l > max_content_size) { - HttpError("content length exceeded " + AsString(max_content_size)); - phase = FAILED; - return; - } - WhenContent(ptr, size); - bodylen += size; -} - -bool HttpRequest::ReadingBody() -{ - LLOG("HTTP reading data " << count); - int n = chunk; - if(count >= 0) - n = min(n, count); - String s = Get(n); - if(s.GetCount() == 0) - return !IsEof() && count; -#ifndef ENDZIP - if(gzip) - z.Put(~s, s.GetCount()); - else -#endif - Out(~s, s.GetCount()); - if(count >= 0) { - count -= s.GetCount(); - return !IsEof() && count > 0; - } - return !IsEof(); -} - -void HttpRequest::Start() -{ - Close(); - ClearError(); - gzip = false; - z.Clear(); - - bool use_proxy = !IsNull(proxy_host); - - int p = use_proxy ? proxy_port : port; - if(!p) - p = DEFAULT_HTTP_PORT; - String h = use_proxy ? proxy_host : host; - if(IsNull(GetTimeout())) { - addrinfo.Execute(h, p); - StartRequest(); - } - else { - addrinfo.Start(h, p); - StartPhase(DNS); - } -} - -void HttpRequest::Dns() -{ - for(int i = 0; i <= Nvl(GetTimeout(), INT_MAX); i++) { - if(!addrinfo.InProgress()) { - StartRequest(); - return; - } - Sleep(1); - } -} - -void HttpRequest::StartRequest() -{ - if(!Connect(addrinfo)) - return; - - StartPhase(REQUEST); - count = 0; - String ctype = contenttype; - if((method == METHOD_POST || method == METHOD_PUT) && IsNull(ctype)) - ctype = "application/x-www-form-urlencoded"; - switch(method) { - case METHOD_GET: data << "GET "; break; - case METHOD_POST: data << "POST "; break; - case METHOD_PUT: data << "PUT "; break; - case METHOD_HEAD: data << "HEAD "; break; - default: NEVER(); // invalid method - } - String host_port = host; - if(port) - host_port << ':' << port; - String url; - url << "http://" << host_port << Nvl(path, "/"); - if(!IsNull(proxy_host)) - data << url; - else - data << Nvl(path, "/"); - data << " HTTP/1.1\r\n"; - if(std_headers) { - data << "URL: " << url << "\r\n" - << "Host: " << host_port << "\r\n" - << "Connection: close\r\n" - << "Accept: " << Nvl(accept, "*/*") << "\r\n" - << "Accept-Encoding: gzip\r\n" - << "Agent: " << Nvl(agent, "Ultimate++ HTTP client") << "\r\n"; - if(postdata.GetCount()) - data << "Content-Length: " << postdata.GetCount() << "\r\n"; - if(ctype.GetCount()) - data << "Content-Type: " << ctype << "\r\n"; - } - if(!IsNull(proxy_host) && !IsNull(proxy_username)) - data << "Proxy-Authorization: Basic " << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; - if(!IsNull(digest)) - data << "Authorization: Digest " << digest << "\r\n"; - else - if(!force_digest && (!IsNull(username) || !IsNull(password))) - data << "Authorization: Basic " << Base64Encode(username + ":" + password) << "\r\n"; - data << request_headers << "\r\n" << postdata; // !!! POST PHASE !!! - LLOG("HTTP REQUEST " << host << ":" << port); - LLOG("HTTP request:\n" << data); -} - -bool HttpRequest::SendingData() -{ - for(;;) { - int n = min(2048, data.GetLength() - count); - n = Put(~data + count, n); - if(n == 0) - break; - count += n; - } - return count < data.GetLength(); -} - -bool HttpRequest::ReadingHeader() -{ - for(;;) { - int c = Get(); - if(c < 0) - return !IsEof(); - else - data.Cat(c); - if(data.GetCount() > 3) { - const char *h = data.Last(); - if(h[0] == '\n' && (h[-1] == '\r' && h[-2] == '\n' || h[-1] == '\n')) - return false; - } - if(data.GetCount() > max_header_size) { - HttpError("HTTP header exceeded " + AsString(max_header_size)); - return true; - } - } -} - -String HttpRequest::Execute() -{ - while(Do()); - return IsSuccess() ? GetContent() : String::GetVoid(); -} - -String HttpRequest::GetPhaseName() const -{ - static const char *m[] = { - "Start", - "Resolving host name", - "Sending request", - "Receiving header", - "Receiving content", - "Receiving chunk header", - "Receiving content chunk", - "Receiving trailer", - "Finished", - "Failed", - }; - return phase >= 0 && phase <= FAILED ? m[phase] : ""; -} - -END_UPP_NAMESPACE +#include "Core.h" + +NAMESPACE_UPP + +bool HttpRequest_Trace__; + +#define LLOG(x) do { if(HttpRequest_Trace__) RLOG(x); } while(0) + +#ifdef _DEBUG +_DBG_ +// #define ENDZIP +#endif + +void HttpRequest::Trace(bool b) +{ + HttpRequest_Trace__ = b; +} + +void HttpRequest::Init() +{ + port = 0; + proxy_port = 0; + max_header_size = 1000000; + max_content_size = 10000000; + max_redirects = 5; + max_retries = 3; + force_digest = false; + std_headers = true; + hasurlvar = false; + method = METHOD_GET; + phase = START; + redirect_count = 0; + retry_count = 0; + gzip = false; + WhenContent = callback(this, &HttpRequest::ContentOut); + chunk = 4096; + timeout = 120000; + ssl = false; +} + +HttpRequest::HttpRequest() +{ + Init(); +} + +HttpRequest::HttpRequest(const char *url) +{ + Init(); + Url(url); +} + +HttpRequest& HttpRequest::Url(const char *u) +{ + ssl = memcmp(u, "https", 5) == 0; + const char *t = u; + while(*t && *t != '?') + if(*t++ == '/' && *t == '/') { + u = ++t; + break; + } + t = u; + while(*u && *u != ':' && *u != '/' && *u != '?') + u++; + if(*u == '?' && u[1]) + hasurlvar = true; + host = String(t, u); + port = 0; + if(*u == ':') + port = ScanInt(u + 1, &u); + path = u; + int q = path.Find('#'); + if(q >= 0) + path.Trim(q); + return *this; +} + +HttpRequest& HttpRequest::Proxy(const char *p) +{ + const char *t = p; + while(*p && *p != ':') + p++; + proxy_host = String(t, p); + proxy_port = 80; + if(*p++ == ':' && IsDigit(*p)) + proxy_port = ScanInt(p); + return *this; +} + +HttpRequest& HttpRequest::Post(const char *id, const String& data) +{ + POST(); + if(postdata.GetCount()) + postdata << '&'; + postdata << id << '=' << UrlEncode(data); + return *this; +} + +HttpRequest& HttpRequest::UrlVar(const char *id, const String& data) +{ + int c = *path.Last(); + if(hasurlvar && c != '&') + path << '&'; + if(!hasurlvar && c != '?') + path << '?'; + path << id << '=' << UrlEncode(data); + hasurlvar = true; + return *this; +} + +String HttpRequest::CalculateDigest(const String& authenticate) const +{ + const char *p = authenticate; + String realm, qop, nonce, opaque; + while(*p) { + if(!IsAlNum(*p)) { + p++; + continue; + } + else { + const char *b = p; + while(IsAlNum(*p)) + p++; + String var = ToLower(String(b, p)); + String value; + while(*p && (byte)*p <= ' ') + p++; + if(*p == '=') { + p++; + while(*p && (byte)*p <= ' ') + p++; + if(*p == '\"') { + p++; + while(*p && *p != '\"') + if(*p != '\\' || *++p) + value.Cat(*p++); + if(*p == '\"') + p++; + } + else { + b = p; + while(*p && *p != ',' && (byte)*p > ' ') + p++; + value = String(b, p); + } + } + if(var == "realm") + realm = value; + else if(var == "qop") + qop = value; + else if(var == "nonce") + nonce = value; + else if(var == "opaque") + opaque = value; + } + } + String hv1, hv2; + hv1 << username << ':' << realm << ':' << password; + String ha1 = MD5String(hv1); + hv2 << (method == METHOD_GET ? "GET" : method == METHOD_PUT ? "PUT" : method == METHOD_POST ? "POST" : "READ") + << ':' << path; + String ha2 = MD5String(hv2); + int nc = 1; + String cnonce = FormatIntHex(Random(), 8); + String hv; + hv << ha1 + << ':' << nonce + << ':' << FormatIntHex(nc, 8) + << ':' << cnonce + << ':' << qop << ':' << ha2; + String ha = MD5String(hv); + String auth; + auth << "username=" << AsCString(username) + << ", realm=" << AsCString(realm) + << ", nonce=" << AsCString(nonce) + << ", uri=" << AsCString(path) + << ", qop=" << AsCString(qop) + << ", nc=" << AsCString(FormatIntHex(nc, 8)) + << ", cnonce=" << cnonce + << ", response=" << AsCString(ha); + if(!IsNull(opaque)) + auth << ", opaque=" << AsCString(opaque); + return auth; +} + +HttpRequest& HttpRequest::Header(const char *id, const String& data) +{ + request_headers << id << ": " << data << "\r\n"; + return *this; +} + +void HttpRequest::HttpError(const char *s) +{ + if(IsError()) + return; + error = NFormat(t_("%s:%d: ") + String(s), host, port); + LLOG("HTTP ERROR: " << error); + Close(); +} + +void HttpRequest::StartPhase(int s) +{ + LLOG("Starting status " << s << ' ' << host); + phase = s; + data.Clear(); +} + +bool HttpRequest::Do() +{ + int c1, c2; + switch(phase) { + case START: + retry_count = 0; + redirect_count = 0; + start_time = msecs(); + Start(); + break; + case DNS: + Dns(); + break; + case REQUEST: + if(SendingData()) + break; + StartPhase(HEADER); + break; + case HEADER: + if(ReadingHeader()) + break; + StartBody(); + break; + case BODY: + if(ReadingBody()) + break; + Finish(); + break; + case CHUNK_HEADER: + ReadingChunkHeader(); + break; + case CHUNK_BODY: + if(ReadingBody()) + break; + c1 = Get(); + c2 = Get(); + if(c1 != '\r' || c2 != '\n') + HttpError("missing ending CRLF in chunked transfer"); + StartPhase(CHUNK_HEADER); + break; + case TRAILER: + if(ReadingHeader()) + break; + header.Parse(data); + Finish(); + break; + case FINISHED: + case FAILED: + return false; + default: + NEVER(); + } + + if(phase != FAILED) + if(IsSocketError() || IsError()) + phase = FAILED; + else + if(msecs() - start_time >= timeout) { + HttpError("connection timed out"); + phase = FAILED; + } + else + if(IsAbort()) { + HttpError("connection was aborted"); + phase = FAILED; + } + + if(phase == FAILED) { + if(retry_count++ < max_retries) { + LLOG("HTTP retry on error " << GetErrorDesc()); + StartRequest(); + } + } + return phase != FINISHED && phase != FAILED; +} + +void HttpRequest::Finish() +{ + if(gzip) { + #ifdef ENDZIP + body = GZDecompress(body); + if(body.IsVoid()) { + HttpError("gzip decompress at finish error"); + phase = FAILED; + return; + } + #else + z.End(); + if(z.IsError()) { + HttpError("gzip format error (finish)"); + phase = FAILED; + return; + } + #endif + } + Close(); + if(status_code == 401 && !IsNull(username)) { + String authenticate = header["www-authenticate"]; + if(authenticate.GetCount() && redirect_count++ < max_redirects) { + LLOG("HTTP auth digest"); + Digest(CalculateDigest(authenticate)); + StartRequest(); + return; + } + } + if(status_code >= 300 && status_code < 400) { + String url = GetRedirectUrl(); + if(url.GetCount() && redirect_count++ < max_redirects) { + LLOG("HTTP redirect " << url); + Url(url); + StartRequest(); + retry_count = 0; + return; + } + } + phase = FINISHED; + +// if(retry_count < 2) +// HttpError("Checking retry"); +} + +void HttpRequest::ReadingChunkHeader() +{ + for(;;) { + int c = Get(); + if(c < 0) + break; + else + if(c == '\n') { + int n = ScanInt(~data, NULL, 16); + LLOG("HTTP Chunk header: 0x" << data << " = " << n); + if(IsNull(n)) { + HttpError("invalid chunk header"); + break; + } + if(n == 0) { + StartPhase(TRAILER); + break; + } + count += n; + StartPhase(CHUNK_BODY); + break; + } + if(c != '\r') + data.Cat(c); + } +} + +String HttpRequest::GetRedirectUrl() +{ + String redirect_url = TrimLeft(header["location"]); + int q = redirect_url.Find('?'); + int p = path.Find('?'); + if(p >= 0 && q < 0) + redirect_url.Cat(path.Mid(p)); + return redirect_url; +} + +int HttpRequest::GetContentLength() +{ + return Nvl(ScanInt(header["content-length"]), -1); +} + +void HttpRequest::StartBody() +{ + LLOG("HTTP Header received: "); + LLOG(data); + header.Clear(); + if(!header.Parse(data)) { + HttpError("invalid HTTP header"); + return; + } + + if(!header.Response(protocol, status_code, reason_phrase)) { + HttpError("invalid HTTP response"); + return; + } + + LLOG("HTTP status code: " << status_code); + + count = GetContentLength(); + + if(count > 0) + body.Reserve(count); + + if(method == METHOD_HEAD) + phase = FINISHED; + else + if(header["transfer-encoding"] == "chunked") { + count = 0; + StartPhase(CHUNK_HEADER); + } + else + StartPhase(BODY); + body.Clear(); + bodylen = 0; + gzip = GetHeader("content-encoding") == "gzip"; + if(gzip) { + gzip = true; + z.WhenOut = callback(this, &HttpRequest::Out); + z.ChunkSize(chunk).GZip().Decompress(); + } +} + +void HttpRequest::ContentOut(const void *ptr, dword size) +{ + body.Cat((const char *)ptr, size); +} + +void HttpRequest::Out(const void *ptr, dword size) +{ + LLOG("HTTP Out " << size); + if(z.IsError()) + HttpError("gzip format error"); + int64 l = bodylen + size; + if(l > max_content_size) { + HttpError("content length exceeded " + AsString(max_content_size)); + phase = FAILED; + return; + } + WhenContent(ptr, size); + bodylen += size; +} + +bool HttpRequest::ReadingBody() +{ + LLOG("HTTP reading data " << count); + int n = chunk; + if(count >= 0) + n = min(n, count); + String s = Get(n); + if(s.GetCount() == 0) + return !IsEof() && count; +#ifndef ENDZIP + if(gzip) + z.Put(~s, s.GetCount()); + else +#endif + Out(~s, s.GetCount()); + if(count >= 0) { + count -= s.GetCount(); + return !IsEof() && count > 0; + } + return !IsEof(); +} + +void HttpRequest::Start() +{ + Close(); + ClearError(); + gzip = false; + z.Clear(); + + bool use_proxy = !IsNull(proxy_host); + + int p = use_proxy ? proxy_port : port; + if(!p) + p = DEFAULT_HTTP_PORT; + String h = use_proxy ? proxy_host : host; + if(IsNull(GetTimeout())) { + addrinfo.Execute(h, p); + StartRequest(); + } + else { + addrinfo.Start(h, p); + StartPhase(DNS); + } +} + +void HttpRequest::Dns() +{ + for(int i = 0; i <= Nvl(GetTimeout(), INT_MAX); i++) { + if(!addrinfo.InProgress()) { + StartRequest(); + return; + } + Sleep(1); + } +} + +void HttpRequest::StartRequest() +{ + if(!Connect(addrinfo)) + return; + + StartPhase(REQUEST); + count = 0; + String ctype = contenttype; + if((method == METHOD_POST || method == METHOD_PUT) && IsNull(ctype)) + ctype = "application/x-www-form-urlencoded"; + switch(method) { + case METHOD_GET: data << "GET "; break; + case METHOD_POST: data << "POST "; break; + case METHOD_PUT: data << "PUT "; break; + case METHOD_HEAD: data << "HEAD "; break; + default: NEVER(); // invalid method + } + String host_port = host; + if(port) + host_port << ':' << port; + String url; + url << "http://" << host_port << Nvl(path, "/"); + if(!IsNull(proxy_host)) + data << url; + else + data << Nvl(path, "/"); + data << " HTTP/1.1\r\n"; + if(std_headers) { + data << "URL: " << url << "\r\n" + << "Host: " << host_port << "\r\n" + << "Connection: close\r\n" + << "Accept: " << Nvl(accept, "*/*") << "\r\n" + << "Accept-Encoding: gzip\r\n" + << "Agent: " << Nvl(agent, "Ultimate++ HTTP client") << "\r\n"; + if(postdata.GetCount()) + data << "Content-Length: " << postdata.GetCount() << "\r\n"; + if(ctype.GetCount()) + data << "Content-Type: " << ctype << "\r\n"; + } + if(!IsNull(proxy_host) && !IsNull(proxy_username)) + data << "Proxy-Authorization: Basic " << Base64Encode(proxy_username + ':' + proxy_password) << "\r\n"; + if(!IsNull(digest)) + data << "Authorization: Digest " << digest << "\r\n"; + else + if(!force_digest && (!IsNull(username) || !IsNull(password))) + data << "Authorization: Basic " << Base64Encode(username + ":" + password) << "\r\n"; + data << request_headers << "\r\n" << postdata; // !!! POST PHASE !!! + LLOG("HTTP REQUEST " << host << ":" << port); + LLOG("HTTP request:\n" << data); + if(ssl) + StartSSL(); +} + +bool HttpRequest::SendingData() +{ + for(;;) { + int n = min(2048, data.GetLength() - count); + n = Put(~data + count, n); + if(n == 0) + break; + count += n; + } + return count < data.GetLength(); +} + +bool HttpRequest::ReadingHeader() +{ + for(;;) { + int c = Get(); + if(c < 0) + return !IsEof(); + else + data.Cat(c); + if(data.GetCount() > 3) { + const char *h = data.Last(); + if(h[0] == '\n' && (h[-1] == '\r' && h[-2] == '\n' || h[-1] == '\n')) + return false; + } + if(data.GetCount() > max_header_size) { + HttpError("HTTP header exceeded " + AsString(max_header_size)); + return true; + } + } +} + +String HttpRequest::Execute() +{ + while(Do()); + return IsSuccess() ? GetContent() : String::GetVoid(); +} + +String HttpRequest::GetPhaseName() const +{ + static const char *m[] = { + "Start", + "Resolving host name", + "Sending request", + "Receiving header", + "Receiving content", + "Receiving chunk header", + "Receiving content chunk", + "Receiving trailer", + "Finished", + "Failed", + }; + return phase >= 0 && phase <= FAILED ? m[phase] : ""; +} + +END_UPP_NAMESPACE diff --git a/uppsrc/Core/SSL/SSL.upp b/uppsrc/Core/SSL/SSL.upp index ef48d5bf5..d40b3064f 100644 --- a/uppsrc/Core/SSL/SSL.upp +++ b/uppsrc/Core/SSL/SSL.upp @@ -1,3 +1,7 @@ +description "\3770,128,128"; + +library(POSIX) "crypto ssl"; + file SSL.h, Util.cpp, diff --git a/uppsrc/Core/SSL/Socket.cpp b/uppsrc/Core/SSL/Socket.cpp index 83639ecab..35602aace 100644 --- a/uppsrc/Core/SSL/Socket.cpp +++ b/uppsrc/Core/SSL/Socket.cpp @@ -1,343 +1,406 @@ -#include "SSL.h" - -NAMESPACE_UPP - -struct TcpSocket::SSLImp : TcpSocket::SSL { - virtual bool Start(TcpSocket& s); - virtual bool Wait(TcpSocket& s, dword flags); - virtual int Send(TcpSocket& s, const void *buffer, int maxlen); - virtual int Recv(TcpSocket& s, void *buffer, int maxlen); - virtual void Close(TcpSocket& s); -}; - -TcpSocket::SSLImp *TcpSocket::CreateSSLImp() -{ - return new TcpSSLImp(); -} - -void InitCreateSSL() -{ - TcpTcpSocket::CreateSSL = sCreate(); -} - -bool TcpSocket::SSLImp::Start(TcpSocket& socket) -{ - if(!(ssl = SSL_new(ssl_context))) { - SetSSLError("OpenClient / SSL_new"); - return false; - } - if(!SSL_set_fd(ssl, socket)) { - SetSSLError("OpenClient / SSL_set_fd"); - return false; - } - int res; - if(mode == ACCEPT) { - SSL_set_accept_state(ssl); - res = SSL_accept(ssl); - } - else { - SSL_set_connect_state(ssl); - res = SSL_connect(ssl); - } - if(res <= 0) { - SetSSLResError("OpenClient / SSL_connect", res); - return false; - } - cert.Set(SSL_get_peer_certificate(ssl)); - return true; -} - -bool TcpSocket::SSLImp::Wait(TcpSocket& s, dword flags) -{ - if((flags & WAIT_READ) && SSL_pending(ssl) > 0) - return true; - return s.RawWait(flags); -} - -int TcpSocket::SSLImp::Send(TcpSocket& s, const void *buffer, int maxlen) -{ - if(!ssl) - return Data::Write(buf, amount); - int res = SSL_write(ssl, (const char *)buf, amount); - if(res <= 0) { - SetSSLResError("SSL_write", res); - return 0; - } - return res; -} - -int TcpSocket::SSLImp::Recv(TcpSocket& s, void *buffer, int maxlen) -{ - int res = SSL_read(ssl, (char *)buf, amount); - if(res == 0) { - s.is_eof = true; - if(SSL_get_shutdown(ssl) & SSL_RECEIVED_SHUTDOWN) - return 0; - } - if(res <= 0) { - SetSSLResError("SSL_read", res); - return 0; - } - return res; -} - -void TcpSocket::SSLImp::Close(TcpSocket& s) -{ - SSL_shutdown(ssl); - s.RawClose(); - SSL_free(ssl); -} - - -#if 0 -class SSLSocketData : public TcpSocket::Data -{ -public: - SSLSocketData(SslContext& context); - virtual ~SSLSocketData(); - - bool OpenClientUnsecured(const char *host, int port, bool nodelay, dword *my_addr, - int timeout, bool is_blocking); - bool Secure(); - bool OpenAccept(SOCKET connection, bool nodelay, bool blocking); - - virtual int GetKind() const { return SOCKKIND_SSL; } - virtual bool Peek(int timeout_msec, bool write); - virtual int Read(void *buf, int amount); - virtual int Write(const void *buf, int amount); - virtual bool Accept(Socket& socket, dword *ipaddr, bool nodelay, int timeout_msec); - virtual bool Close(int timeout_msec); - virtual Value GetInfo(String info) const; - - void SetSSLError(const char *context); - void SetSSLResError(const char *context, int res); - -public: - SslContext& ssl_context; - SSL *ssl; - SslCertificate cert; -}; - -SSLSocketData::SSLSocketData(SslContext& ssl_context) -: ssl_context(ssl_context) -{ - SSLInit().AddThread(); - ssl = NULL; -} - -SSLSocketData::~SSLSocketData() -{ - Close(0); -} - -void SSLSocketData::SetSSLError(const char *context) -{ - if(sock) { - int code; - String text = SSLGetLastError(code); - SetSockError(context, code, text); - } -} - -void SSLSocketData::SetSSLResError(const char *context, int res) -{ - if(sock) { - int code = SSL_get_error(ssl, res); - String out; - switch(code) { - #define SSLERR(c) case c: out = #c; break; - SSLERR(SSL_ERROR_NONE) - SSLERR(SSL_ERROR_SSL) - SSLERR(SSL_ERROR_WANT_READ) - SSLERR(SSL_ERROR_WANT_WRITE) - SSLERR(SSL_ERROR_WANT_X509_LOOKUP) - SSLERR(SSL_ERROR_SYSCALL) - SSLERR(SSL_ERROR_ZERO_RETURN) - SSLERR(SSL_ERROR_WANT_CONNECT) - #ifdef PLATFORM_WIN32 - SSLERR(SSL_ERROR_WANT_ACCEPT) - #endif - default: out = "unknown code"; break; - } - SetSockError(context, code, out); - } -} - -bool SSLSocketData::Peek(int timeout_msec, bool write) -{ - if(ssl && !write && SSL_pending(ssl) > 0) - return true; - return Data::Peek(timeout_msec, write); -} - -int SSLSocketData::Read(void *buf, int amount) -{ - if(!ssl) - return Data::Read(buf, amount); - int res = SSL_read(ssl, (char *)buf, amount); - if(res == 0) { - is_eof = true; - if(SSL_get_shutdown(ssl) & SSL_RECEIVED_SHUTDOWN) - return 0; - } - if(res <= 0) - SetSSLResError("SSL_read", res); -#ifndef NOFAKEERROR - if(fake_error && res > 0) { - if((fake_error -= res) <= 0) { - fake_error = 0; - SetSockError("SSL_read", 0, "fake error"); - return -1; - } - else - RLOG("SSLSocketData::Read: fake error after " << fake_error); - } -#endif - return res; -} - -int SSLSocketData::Write(const void *buf, int amount) -{ - if(!ssl) - return Data::Write(buf, amount); - int res = SSL_write(ssl, (const char *)buf, amount); - if(res <= 0) - SetSSLResError("SSL_write", res); - return res; -} - -bool SSLSocketData::OpenClientUnsecured(const char *host, int port, bool nodelay, dword *my_addr, - int timeout, bool blocking) -{ - return Data::OpenClient(host, port, nodelay, my_addr, timeout, /*blocking*/true); -} - -bool SSLSocketData::Secure() -{ - if(!(ssl = SSL_new(ssl_context))) - { - SetSSLError("OpenClient / SSL_new"); - return false; - } - if(!SSL_set_fd(ssl, socket)) - { - SetSSLError("OpenClient / SSL_set_fd"); - return false; - } - SSL_set_connect_state(ssl); - int res = SSL_connect(ssl); - if(res <= 0) - { - SetSSLResError("OpenClient / SSL_connect", res); - return false; - } - cert.Set(SSL_get_peer_certificate(ssl)); - return true; -} - -bool SSLSocketData::OpenAccept(SOCKET conn, bool nodelay, bool blocking) -{ - Attach(conn, nodelay, blocking); - if(!(ssl = SSL_new(ssl_context))) - { - SetSSLError("Accept / SSL_new"); - return false; - } - if(!SSL_set_fd(ssl, socket)) - { - SetSSLError("Accept / SSL_set_fd"); - return false; - } - SSL_set_accept_state(ssl); - int res = SSL_accept(ssl); - if(res <= 0) - { - SetSSLResError("Accept / SSL_accept", res); - return false; - } - cert.Set(SSL_get_peer_certificate(ssl)); - return true; -} - -bool SSLSocketData::Accept(Socket& socket, dword *ipaddr, bool nodelay, int timeout_msec) -{ - SOCKET connection = AcceptRaw(ipaddr, timeout_msec); - if(connection == INVALID_SOCKET) - return false; - One data = new SSLSocketData(ssl_context); - if(!data->OpenAccept(connection, nodelay, is_blocking)) - return false; - socket.Attach(-data); - return true; -} - -bool SSLSocketData::Close(int timeout_msec) -{ - if(ssl) - SSL_shutdown(ssl); - bool res = Data::Close(timeout_msec); - if(ssl) { - SSL_free(ssl); - ssl = NULL; - } - return res; -} - -Value SSLSocketData::GetInfo(String info) const -{ - if(info == SSLInfoCipher()) return SSL_get_cipher(ssl); - if(info == SSLInfoCertAvail()) return cert.IsEmpty() ? 0 : 1; - if(info == SSLInfoCertVerified()) return SSL_get_verify_result(ssl) == X509_V_OK ? 1 : 0; - if(info == SSLInfoCertSubjectName()) return cert.IsEmpty() ? String::GetVoid() : cert.GetSubjectName(); - if(info == SSLInfoCertIssuerName()) return cert.IsEmpty() ? String::GetVoid() : cert.GetIssuerName(); - if(info == SSLInfoCertNotBefore()) return cert.IsEmpty() ? Date(Null) : cert.GetNotBefore(); - if(info == SSLInfoCertNotAfter()) return cert.IsEmpty() ? Date(Null) : cert.GetNotAfter(); - if(info == SSLInfoCertVersion()) return cert.IsEmpty() ? int(Null) : cert.GetVersion(); - if(info == SSLInfoCertSerialNumber()) return cert.IsEmpty() ? String::GetVoid() : cert.GetSerialNumber(); - - return Data::GetInfo(info); -} - -bool SSLServerSocket(Socket& socket, SslContext& ssl_context, int port, bool nodelay, int listen_count, bool blocking) -{ - One data = new SSLSocketData(ssl_context); - if(!data->OpenServer(port, nodelay, listen_count, blocking)) - return false; - socket.Attach(-data); - return true; -} - -bool SSLClientSocket(Socket& socket, SslContext& ssl_context, const char *host, int port, bool nodelay, - dword *my_addr, int timeout, bool blocking) -{ - One data = new SSLSocketData(ssl_context); - if(!data->OpenClient(host, port, nodelay, my_addr, timeout, blocking)) - return false; - if(!data->Secure()) - return false; - socket.Attach(-data); - return true; -} - -bool SSLClientSocketUnsecured(Socket& socket, SslContext& ssl_context, const char *host, - int port, bool nodelay, dword *my_addr, int timeout, - bool is_blocking) -{ - One data = new SSLSocketData(ssl_context); - if(data->OpenClientUnsecured(host, port, nodelay, my_addr, timeout, is_blocking)) { - socket.Attach(-data); - return true; - } - return false; -} - -bool SSLSecureSocket(Socket& socket) -{ - SSLSocketData *sd = dynamic_cast(~socket.data); - if(!sd) - return false; - return sd->Secure(); -} -#endif - -END_UPP_NAMESPACE +#include "SSL.h" + +NAMESPACE_UPP + +struct TcpSocket::SSLImp : TcpSocket::SSL { + virtual bool Start(); + virtual bool Wait(dword flags); + virtual int Send(const void *buffer, int maxlen); + virtual int Recv(void *buffer, int maxlen); + virtual void Close(); + + TcpSocket& socket; + SslContext context; + ::SSL *ssl; + SslCertificate cert; + + void SetSSLError(const char *context); + void SetSSLResError(const char *context, int res); + bool IsAgain(int res) const; + + SSLImp(TcpSocket& socket) : socket(socket) {} +}; + +TcpSocket::SSL *TcpSocket::CreateSSLImp(TcpSocket& socket) +{ + return new TcpSocket::SSLImp(socket); +} + +void InitCreateSSL() +{ + TcpSocket::CreateSSL = TcpSocket::CreateSSLImp; +} + +INITBLOCK { + InitCreateSSL(); +} + +void TcpSocket::SSLImp::SetSSLError(const char *context) +{ + int code; + String text = SslGetLastError(code); + socket.SetSockError(context, code, text); +} + +void TcpSocket::SSLImp::SetSSLResError(const char *context, int res) +{ + int code = SSL_get_error(ssl, res); + String out; + switch(code) { +#define SSLERR(c) case c: out = #c; break; + SSLERR(SSL_ERROR_NONE) + SSLERR(SSL_ERROR_SSL) + SSLERR(SSL_ERROR_WANT_READ) + SSLERR(SSL_ERROR_WANT_WRITE) + SSLERR(SSL_ERROR_WANT_X509_LOOKUP) + SSLERR(SSL_ERROR_SYSCALL) + SSLERR(SSL_ERROR_ZERO_RETURN) + SSLERR(SSL_ERROR_WANT_CONNECT) +#ifdef PLATFORM_WIN32 + SSLERR(SSL_ERROR_WANT_ACCEPT) +#endif + default: out = "unknown code"; break; + } + socket.SetSockError(context, code, out); +} + +bool TcpSocket::SSLImp::IsAgain(int res) const +{ + res = SSL_get_error(ssl, res); + return res == SSL_ERROR_WANT_READ || + res == SSL_ERROR_WANT_WRITE || + res == SSL_ERROR_WANT_CONNECT || + res == SSL_ERROR_WANT_ACCEPT; +} + +bool TcpSocket::SSLImp::Start() +{ + if(!context.Create(const_cast(SSLv3_client_method()))) { + SetSSLError("Start: SSL context."); + return false; + } + if(!(ssl = SSL_new(context))) { + SetSSLError("Start: SSL_new"); + return false; + } + if(!SSL_set_fd(ssl, socket.GetSOCKET())) { + SetSSLError("Start: SSL_set_fd"); + return false; + } + int res; + if(socket.mode == ACCEPT) { + SSL_set_accept_state(ssl); + int res = SSL_accept(ssl); + if(res <= 0 && !IsAgain(res)) { + SetSSLResError("Start: SSL_accept", res); + return false; + } + } + else { + SSL_set_connect_state(ssl); + res = SSL_connect(ssl); + if(res <= 0 && !IsAgain(res)) { + SetSSLResError("Start: SSL_connect", res); + return false; + } + } + cert.Set(SSL_get_peer_certificate(ssl)); + return true; +} + +bool TcpSocket::SSLImp::Wait(dword flags) +{ + if((flags & WAIT_READ) && SSL_pending(ssl) > 0) + return true; + return socket.RawWait(flags); +} + +int TcpSocket::SSLImp::Send(const void *buffer, int maxlen) +{ + int res = SSL_write(ssl, (const char *)buffer, maxlen); + if(IsAgain(res)) + return 0; + if(res <= 0) { + SetSSLResError("SSL_write", res); + return 0; + } + return res; +} + +int TcpSocket::SSLImp::Recv(void *buffer, int maxlen) +{ + int res = SSL_read(ssl, (char *)buffer, maxlen); + if(IsAgain(res)) + return 0; + if(res == 0) { + socket.is_eof = true; + if(SSL_get_shutdown(ssl) & SSL_RECEIVED_SHUTDOWN) + return 0; + } + if(res <= 0) { + SetSSLResError("SSL_read", res); + return 0; + } + return res; +} + +void TcpSocket::SSLImp::Close() +{ + SSL_shutdown(ssl); + socket.RawClose(); + SSL_free(ssl); +} + + +#if 0 +class SSLSocketData : public TcpSocket::Data +{ +public: + SSLSocketData(SslContext& context); + virtual ~SSLSocketData(); + + bool OpenClientUnsecured(const char *host, int port, bool nodelay, dword *my_addr, + int timeout, bool is_blocking); + bool Secure(); + bool OpenAccept(SOCKET connection, bool nodelay, bool blocking); + + virtual int GetKind() const { return SOCKKIND_SSL; } + virtual bool Peek(int timeout_msec, bool write); + virtual int Read(void *buf, int amount); + virtual int Write(const void *buf, int amount); + virtual bool Accept(Socket& socket, dword *ipaddr, bool nodelay, int timeout_msec); + virtual bool Close(int timeout_msec); + virtual Value GetInfo(String info) const; + + void SetSSLError(const char *context); + void SetSSLResError(const char *context, int res); + +public: + SslContext& ssl_context; + SSL *ssl; + SslCertificate cert; +}; + +SSLSocketData::SSLSocketData(SslContext& ssl_context) +: ssl_context(ssl_context) +{ + SSLInit().AddThread(); + ssl = NULL; +} + +SSLSocketData::~SSLSocketData() +{ + Close(0); +} + +void SSLSocketData::SetSSLError(const char *context) +{ + if(sock) { + int code; + String text = SSLGetLastError(code); + SetSockError(context, code, text); + } +} + +void SSLSocketData::SetSSLResError(const char *context, int res) +{ + if(sock) { + int code = SSL_get_error(ssl, res); + String out; + switch(code) { + #define SSLERR(c) case c: out = #c; break; + SSLERR(SSL_ERROR_NONE) + SSLERR(SSL_ERROR_SSL) + SSLERR(SSL_ERROR_WANT_READ) + SSLERR(SSL_ERROR_WANT_WRITE) + SSLERR(SSL_ERROR_WANT_X509_LOOKUP) + SSLERR(SSL_ERROR_SYSCALL) + SSLERR(SSL_ERROR_ZERO_RETURN) + SSLERR(SSL_ERROR_WANT_CONNECT) + #ifdef PLATFORM_WIN32 + SSLERR(SSL_ERROR_WANT_ACCEPT) + #endif + default: out = "unknown code"; break; + } + SetSockError(context, code, out); + } +} + +bool SSLSocketData::Peek(int timeout_msec, bool write) +{ + if(ssl && !write && SSL_pending(ssl) > 0) + return true; + return Data::Peek(timeout_msec, write); +} + +int SSLSocketData::Read(void *buf, int amount) +{ + if(!ssl) + return Data::Read(buf, amount); + int res = SSL_read(ssl, (char *)buf, amount); + if(res == 0) { + is_eof = true; + if(SSL_get_shutdown(ssl) & SSL_RECEIVED_SHUTDOWN) + return 0; + } + if(res <= 0) + SetSSLResError("SSL_read", res); +#ifndef NOFAKEERROR + if(fake_error && res > 0) { + if((fake_error -= res) <= 0) { + fake_error = 0; + SetSockError("SSL_read", 0, "fake error"); + return -1; + } + else + RLOG("SSLSocketData::Read: fake error after " << fake_error); + } +#endif + return res; +} + +int SSLSocketData::Write(const void *buf, int amount) +{ + if(!ssl) + return Data::Write(buf, amount); + int res = SSL_write(ssl, (const char *)buf, amount); + if(res <= 0) + SetSSLResError("SSL_write", res); + return res; +} + +bool SSLSocketData::OpenClientUnsecured(const char *host, int port, bool nodelay, dword *my_addr, + int timeout, bool blocking) +{ + return Data::OpenClient(host, port, nodelay, my_addr, timeout, /*blocking*/true); +} + +bool SSLSocketData::Secure() +{ + if(!(ssl = SSL_new(ssl_context))) + { + SetSSLError("OpenClient / SSL_new"); + return false; + } + if(!SSL_set_fd(ssl, socket)) + { + SetSSLError("OpenClient / SSL_set_fd"); + return false; + } + SSL_set_connect_state(ssl); + int res = SSL_connect(ssl); + if(res <= 0) + { + SetSSLResError("OpenClient / SSL_connect", res); + return false; + } + cert.Set(SSL_get_peer_certificate(ssl)); + return true; +} + +bool SSLSocketData::OpenAccept(SOCKET conn, bool nodelay, bool blocking) +{ + Attach(conn, nodelay, blocking); + if(!(ssl = SSL_new(ssl_context))) + { + SetSSLError("Accept / SSL_new"); + return false; + } + if(!SSL_set_fd(ssl, socket)) + { + SetSSLError("Accept / SSL_set_fd"); + return false; + } + SSL_set_accept_state(ssl); + int res = SSL_accept(ssl); + if(res <= 0) + { + SetSSLResError("Accept / SSL_accept", res); + return false; + } + cert.Set(SSL_get_peer_certificate(ssl)); + return true; +} + +bool SSLSocketData::Accept(Socket& socket, dword *ipaddr, bool nodelay, int timeout_msec) +{ + SOCKET connection = AcceptRaw(ipaddr, timeout_msec); + if(connection == INVALID_SOCKET) + return false; + One data = new SSLSocketData(ssl_context); + if(!data->OpenAccept(connection, nodelay, is_blocking)) + return false; + socket.Attach(-data); + return true; +} + +bool SSLSocketData::Close(int timeout_msec) +{ + if(ssl) + SSL_shutdown(ssl); + bool res = Data::Close(timeout_msec); + if(ssl) { + SSL_free(ssl); + ssl = NULL; + } + return res; +} + +Value SSLSocketData::GetInfo(String info) const +{ + if(info == SSLInfoCipher()) return SSL_get_cipher(ssl); + if(info == SSLInfoCertAvail()) return cert.IsEmpty() ? 0 : 1; + if(info == SSLInfoCertVerified()) return SSL_get_verify_result(ssl) == X509_V_OK ? 1 : 0; + if(info == SSLInfoCertSubjectName()) return cert.IsEmpty() ? String::GetVoid() : cert.GetSubjectName(); + if(info == SSLInfoCertIssuerName()) return cert.IsEmpty() ? String::GetVoid() : cert.GetIssuerName(); + if(info == SSLInfoCertNotBefore()) return cert.IsEmpty() ? Date(Null) : cert.GetNotBefore(); + if(info == SSLInfoCertNotAfter()) return cert.IsEmpty() ? Date(Null) : cert.GetNotAfter(); + if(info == SSLInfoCertVersion()) return cert.IsEmpty() ? int(Null) : cert.GetVersion(); + if(info == SSLInfoCertSerialNumber()) return cert.IsEmpty() ? String::GetVoid() : cert.GetSerialNumber(); + + return Data::GetInfo(info); +} + +bool SSLServerSocket(Socket& socket, SslContext& ssl_context, int port, bool nodelay, int listen_count, bool blocking) +{ + One data = new SSLSocketData(ssl_context); + if(!data->OpenServer(port, nodelay, listen_count, blocking)) + return false; + socket.Attach(-data); + return true; +} + +bool SSLClientSocket(Socket& socket, SslContext& ssl_context, const char *host, int port, bool nodelay, + dword *my_addr, int timeout, bool blocking) +{ + One data = new SSLSocketData(ssl_context); + if(!data->OpenClient(host, port, nodelay, my_addr, timeout, blocking)) + return false; + if(!data->Secure()) + return false; + socket.Attach(-data); + return true; +} + +bool SSLClientSocketUnsecured(Socket& socket, SslContext& ssl_context, const char *host, + int port, bool nodelay, dword *my_addr, int timeout, + bool is_blocking) +{ + One data = new SSLSocketData(ssl_context); + if(data->OpenClientUnsecured(host, port, nodelay, my_addr, timeout, is_blocking)) { + socket.Attach(-data); + return true; + } + return false; +} + +bool SSLSecureSocket(Socket& socket) +{ + SSLSocketData *sd = dynamic_cast(~socket.data); + if(!sd) + return false; + return sd->Secure(); +} +#endif + +END_UPP_NAMESPACE diff --git a/uppsrc/Core/Socket.cpp b/uppsrc/Core/Socket.cpp index c08019e38..b19184632 100644 --- a/uppsrc/Core/Socket.cpp +++ b/uppsrc/Core/Socket.cpp @@ -1,775 +1,780 @@ -#include "Core.h" - -#ifdef PLATFORM_WIN32 -#include -#include -#include -#endif - -#ifdef PLATFORM_POSIX -#include -#endif - -NAMESPACE_UPP - -#ifdef PLATFORM_WIN32 -#pragma comment(lib, "ws2_32.lib") -#endif - -#define LLOG(x) // DLOG("TCP " << x) - -IpAddrInfo::Entry IpAddrInfo::pool[COUNT]; - -RawMutex IpAddrInfoPoolMutex; - -void IpAddrInfo::EnterPool() -{ - IpAddrInfoPoolMutex.Enter(); -} - -void IpAddrInfo::LeavePool() -{ - IpAddrInfoPoolMutex.Leave(); -} - -int sGetAddrInfo(const char *host, const char *port, addrinfo **result) -{ - addrinfo hints; - memset(&hints, 0, sizeof(addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; - - return getaddrinfo(host, port, &hints, result); -} - -rawthread_t rawthread__ IpAddrInfo::Thread(void *ptr) -{ - Entry *entry = (Entry *)ptr; - EnterPool(); - if(entry->status == WORKING) { - char host[1025]; - char port[257]; - strcpy(host, entry->host); - strcpy(port, entry->port); - LeavePool(); - addrinfo *result; - if(sGetAddrInfo(host, port, &result) == 0 && result) { - EnterPool(); - if(entry->status == WORKING) { - entry->addr = result; - entry->status = RESOLVED; - } - else { - freeaddrinfo(entry->addr); - entry->status = EMPTY; - } - } - else { - EnterPool(); - if(entry->status == CANCELED) - entry->status = EMPTY; - else - entry->status = FAILED; - } - } - LeavePool(); - return 0; -} - -bool IpAddrInfo::Execute(const String& host, int port) -{ - Clear(); - entry = exe; - addrinfo *result; - entry->addr = sGetAddrInfo(~host, ~AsString(port), &result) == 0 ? result : NULL; - return entry->addr; -} - -void IpAddrInfo::Start() -{ - if(entry) - return; - EnterPool(); - for(int i = 0; i < COUNT; i++) { - Entry *e = pool + i; - if(e->status == EMPTY) { - entry = e; - e->addr = NULL; - if(host.GetCount() > 1024 || port.GetCount() > 256) - e->status = FAILED; - else { - e->status = WORKING; - e->host = host; - e->port = port; - StartRawThread(&IpAddrInfo::Thread, e); - } - break; - } - } - LeavePool(); -} - -void IpAddrInfo::Start(const String& host_, int port_) -{ - Clear(); - port = AsString(port_); - host = host_; - Start(); -} - -bool IpAddrInfo::InProgress() -{ - if(!entry) { - Start(); - return true; - } - EnterPool(); - int s = entry->status; - LeavePool(); - return s == WORKING; -} - -addrinfo *IpAddrInfo::GetResult() -{ - EnterPool(); - addrinfo *ai = entry ? entry->addr : NULL; - LeavePool(); - return ai; -} - -void IpAddrInfo::Clear() -{ - EnterPool(); - if(entry) { - if(entry->status == RESOLVED && entry->addr) - freeaddrinfo(entry->addr); - if(entry->status == WORKING) - entry->status = CANCELED; - entry->status = EMPTY; - entry = NULL; - } - LeavePool(); -} - -IpAddrInfo::IpAddrInfo() -{ - TcpSocket::Init(); - entry = NULL; -} - -#ifdef PLATFORM_POSIX - -#define SOCKERR(x) x - -const char *TcpSocketErrorDesc(int code) -{ - return strerror(code); -} - -int TcpSocket::GetErrorCode() -{ - return errno; -} - -#else - -#define SOCKERR(x) WSA##x - -const char *TcpSocketErrorDesc(int code) -{ - static Tuple2 err[] = { - { WSAEINTR, "Interrupted function call." }, - { WSAEACCES, "Permission denied." }, - { WSAEFAULT, "Bad address." }, - { WSAEINVAL, "Invalid argument." }, - { WSAEMFILE, "Too many open files." }, - { WSAEWOULDBLOCK, "Resource temporarily unavailable." }, - { WSAEINPROGRESS, "Operation now in progress." }, - { WSAEALREADY, "Operation already in progress." }, - { WSAENOTSOCK, "TcpSocket operation on nonsocket." }, - { WSAEDESTADDRREQ, "Destination address required." }, - { WSAEMSGSIZE, "Message too long." }, - { WSAEPROTOTYPE, "Protocol wrong type for socket." }, - { WSAENOPROTOOPT, "Bad protocol option." }, - { WSAEPROTONOSUPPORT, "Protocol not supported." }, - { WSAESOCKTNOSUPPORT, "TcpSocket type not supported." }, - { WSAEOPNOTSUPP, "Operation not supported." }, - { WSAEPFNOSUPPORT, "Protocol family not supported." }, - { WSAEAFNOSUPPORT, "Address family not supported by protocol family." }, - { WSAEADDRINUSE, "Address already in use." }, - { WSAEADDRNOTAVAIL, "Cannot assign requested address." }, - { WSAENETDOWN, "Network is down." }, - { WSAENETUNREACH, "Network is unreachable." }, - { WSAENETRESET, "Network dropped connection on reset." }, - { WSAECONNABORTED, "Software caused connection abort." }, - { WSAECONNRESET, "Connection reset by peer." }, - { WSAENOBUFS, "No buffer space available." }, - { WSAEISCONN, "TcpSocket is already connected." }, - { WSAENOTCONN, "TcpSocket is not connected." }, - { WSAESHUTDOWN, "Cannot send after socket shutdown." }, - { WSAETIMEDOUT, "Connection timed out." }, - { WSAECONNREFUSED, "Connection refused." }, - { WSAEHOSTDOWN, "Host is down." }, - { WSAEHOSTUNREACH, "No route to host." }, - { WSAEPROCLIM, "Too many processes." }, - { WSASYSNOTREADY, "Network subsystem is unavailable." }, - { WSAVERNOTSUPPORTED, "Winsock.dll version out of range." }, - { WSANOTINITIALISED, "Successful WSAStartup not yet performed." }, - { WSAEDISCON, "Graceful shutdown in progress." }, - { WSATYPE_NOT_FOUND, "Class type not found." }, - { WSAHOST_NOT_FOUND, "Host not found." }, - { WSATRY_AGAIN, "Nonauthoritative host not found." }, - { WSANO_RECOVERY, "This is a nonrecoverable error." }, - { WSANO_DATA, "Valid name, no data record of requested type." }, - { WSASYSCALLFAILURE, "System call failure." }, - }; - const Tuple2 *x = FindTuple(err, __countof(err), code); - return x ? x->b : "Unknown error code."; -} - -int TcpSocket::GetErrorCode() -{ - return WSAGetLastError(); -} - -#endif - -void TcpSocketInit() -{ -#if defined(PLATFORM_WIN32) - ONCELOCK { - WSADATA wsadata; - WSAStartup(MAKEWORD(2, 2), &wsadata); - } -#endif -} - -void TcpSocket::Init() -{ - TcpSocketInit(); -} - -void TcpSocket::Reset() -{ - is_eof = false; - socket = INVALID_SOCKET; - ipv6 = false; - ptr = end = buffer; - is_error = false; - is_abort = false; - mode = NONE; - ssl.Clear(); -} - -TcpSocket::TcpSocket() -{ - ClearError(); - Reset(); - timeout = Null; - waitstep = 20; -} - -bool TcpSocket::Open(int family, int type, int protocol) -{ - Init(); - Close(); - ClearError(); - if((socket = ::socket(family, type, protocol)) == INVALID_SOCKET) - return false; - LLOG("TcpSocket::Data::Open() -> " << (int)socket); -#ifdef PLATFORM_WIN32 - u_long arg = 1; - if(ioctlsocket(socket, FIONBIO, &arg)) - SetSockError("ioctlsocket(FIO[N]BIO)"); -#else - if(fcntl(socket, F_SETFL, (fcntl(socket, F_GETFL, 0) | O_NONBLOCK))) - SetSockError("fcntl(O_[NON]BLOCK)"); -#endif - return true; -} - -bool TcpSocket::Listen(int port, int listen_count, bool ipv6_, bool reuse) -{ - Init(); - - ipv6 = ipv6_; - if(!Open(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) - return false; - sockaddr_in sin; -#ifdef PLATFORM_WIN32 - SOCKADDR_IN6 sin6; - if(ipv6 && IsWinVista()) -#else - sockaddr_in6 sin6; - if(ipv6) -#endif - { - Zero(sin6); - sin.sin_family = AF_INET6; - sin.sin_port = htons(port); - sin.sin_addr.s_addr = htonl(INADDR_ANY); - } - else { - Zero(sin); - sin.sin_family = AF_INET; - sin.sin_port = htons(port); - sin.sin_addr.s_addr = htonl(INADDR_ANY); - } - if(reuse) { - int optval = 1; - setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (const char *)&optval, sizeof(optval)); - } - if(bind(socket, ipv6 ? (const sockaddr *)&sin6 : (const sockaddr *)&sin, - ipv6 ? sizeof(sin6) : sizeof(sin))) { - SetSockError(Format("bind(port=%d)", port)); - return false; - } - if(listen(socket, listen_count)) { - SetSockError(Format("listen(port=%d, count=%d)", port, listen_count)); - return false; - } - return true; -} - -bool TcpSocket::Accept(TcpSocket& ls) -{ - Close(); - if(timeout && !ls.WaitRead()) - return false; - if(!Open(ls.ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) - return false; - socket = accept(ls.GetSOCKET(), NULL, NULL); - if(socket == INVALID_SOCKET) { - SetSockError("accept"); - return false; - } - mode = ACCEPT; - return true; -} - -String TcpSocket::GetPeerAddr() const -{ - if(!IsOpen()) - return Null; - sockaddr_in addr; - socklen_t l = sizeof(addr); - if(getpeername(socket, (sockaddr *)&addr, &l) != 0) - return Null; - if(l > sizeof(addr)) - return Null; -#ifdef PLATFORM_WIN32 - return inet_ntoa(addr.sin_addr); -#else - char h[200]; - return inet_ntop(AF_INET, &addr.sin_addr, h, 200); -#endif -} - -void TcpSocket::NoDelay() -{ - ASSERT(IsOpen()); - int __true = 1; - LLOG("NoDelay(" << (int)socket << ")"); - if(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (const char *)&__true, sizeof(__true))) - SetSockError("setsockopt(TCP_NODELAY)"); -} - -void TcpSocket::Linger(int msecs) -{ - ASSERT(IsOpen()); - linger ls; - ls.l_onoff = !IsNull(msecs) ? 1 : 0; - ls.l_linger = !IsNull(msecs) ? (msecs + 999) / 1000 : 0; - if(setsockopt(socket, SOL_SOCKET, SO_LINGER, - reinterpret_cast(&ls), sizeof(ls))) - SetSockError("setsockopt(SO_LINGER)"); -} - -void TcpSocket::Attach(SOCKET s) -{ - Close(); - socket = s; -} - -bool TcpSocket::RawConnect(addrinfo *rp) -{ - for(;;) { - if(Open(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) { - if(connect(socket, rp->ai_addr, rp->ai_addrlen) == 0 || - GetErrorCode() == SOCKERR(EINPROGRESS) || GetErrorCode() == SOCKERR(EWOULDBLOCK)) - break; - Close(); - } - rp = rp->ai_next; - if(!rp) { - SetSockError("connect has failed"); - return false; - } - } - mode = CONNECT; - return true; -} - - -bool TcpSocket::Connect(IpAddrInfo& info) -{ - LLOG("TCP Connect addrinfo"); - Init(); - addrinfo *result = info.GetResult(); - return result && RawConnect(result); -} - -bool TcpSocket::Connect(const char *host, int port) -{ - LLOG("TCP Connect(" << host << ':' << port << ')'); - - Init(); - IpAddrInfo info; - if(!info.Execute(host, port)) { - SetSockError(Format("getaddrinfo(%s) failed", host)); - return false; - } - return Connect(info); -} - -void TcpSocket::RawClose() -{ - LLOG("TCP close " << (int)socket); - if(socket != INVALID_SOCKET) { - int res; -#if defined(PLATFORM_WIN32) - res = closesocket(socket); -#elif defined(PLATFORM_POSIX) - res = close(socket); -#else - #error Unsupported platform -#endif - if(res && !IsError()) - SetSockError("close"); - socket = INVALID_SOCKET; - } -} - -void TcpSocket::Close() -{ - if(ssl) - ssl->Close(*this); - else - RawClose(); - ssl.Clear(); -} - -bool TcpSocket::WouldBlock() -{ - int c = GetErrorCode(); -#ifdef PLATFORM_POSIX - return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(EAGAIN); -#endif -#ifdef PLATFORM_WIN32 - return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(ENOTCONN); -#endif -} - -int TcpSocket::RawSend(const void *buf, int amount) -{ - int res = send(socket, (const char *)buf, amount, 0); - if(res < 0 && WouldBlock()) - res = 0; - else - if(res == 0 || res < 0) - SetSockError("send"); - return res; -} - -int TcpSocket::Send(const void *buf, int amount) -{ - return ssl ? ssl->Send(*this, buf, amount) : RawSend(buf, amount); -} - -void TcpSocket::Shutdown() -{ - ASSERT(IsOpen()); - if(shutdown(socket, SD_SEND)) - SetSockError("shutdown(SD_SEND)"); -} - -String TcpSocket::GetHostName() -{ - Init(); - char buffer[256]; - gethostname(buffer, __countof(buffer)); - return buffer; -} - -bool TcpSocket::RawWait(dword flags) -{ - LLOG("Wait(" << timeout << ", " << flags << ")"); - if((flags & WAIT_READ) && ptr != end) - return true; - int end_time = msecs() + timeout; - if(socket == INVALID_SOCKET) - return false; - for(;;) { - if(IsError() || IsAbort()) - return false; - int to = end_time - msecs(); - if(WhenWait) - to = waitstep; - timeval *tvalp = NULL; - timeval tval; - if(!IsNull(timeout) || WhenWait) { - to = max(to, 0); - tval.tv_sec = to / 1000; - tval.tv_usec = 1000 * (to % 1000); - tvalp = &tval; - } - fd_set fdset[1]; - FD_ZERO(fdset); - FD_SET(socket, fdset); - int avail = select((int)socket + 1, - flags & WAIT_READ ? fdset : NULL, - flags & WAIT_WRITE ? fdset : NULL, - flags & WAIT_EXCEPTION ? fdset : NULL, tvalp); - LLOG("Wait select avail: " << avail); - if(avail < 0) { - SetSockError("wait"); - return false; - } - if(avail > 0) - return true; - if(to <= 0 && timeout) { - return false; - } - WhenWait(); - if(timeout == 0) - return false; - } -} - -bool TcpSocket::Wait(dword flags) -{ - return ssl ? ssl->Wait(*this, flags) : RawWait(flags); -} - -int TcpSocket::Put(const char *s, int length) -{ - LLOG("Put " << socket << ": " << length); - ASSERT(IsOpen()); - if(length < 0 && s) - length = (int)strlen(s); - if(!s || length <= 0 || IsError() || IsAbort()) - return 0; - done = 0; - bool peek = false; - while(done < length) { - if(peek && !WaitWrite()) - return done; - peek = false; - int count = Send(s + done, length - done); - if(IsError() || timeout == 0 && count == 0 && peek) - return done; - if(count > 0) - done += count; - else - peek = true; - } - LLOG("//Put() -> " << done); - return done; -} - -int TcpSocket::RawRecv(void *buf, int amount) -{ - int res = recv(socket, (char *)buf, amount, 0); - if(res == 0) - is_eof = true; - else - if(res < 0 && WouldBlock()) - res = 0; - else - if(res < 0) - SetSockError("recv"); - LLOG("recv(" << socket << "): " << res << " bytes: " - << AsCString((char *)buf, (char *)buf + min(res, 16)) - << (res ? "" : IsEof() ? ", EOF" : ", WOULDBLOCK")); - return res; -} - -int TcpSocket::Recv(void *buffer, int maxlen) -{ - return ssl ? ssl->Recv(*this, buffer, maxlen) : RawRecv(buffer, maxlen); -} - -void TcpSocket::ReadBuffer() -{ - ptr = end = buffer; - if(WaitRead()) - end = buffer + Recv(buffer, BUFFERSIZE); -} - -int TcpSocket::Get_() -{ - if(!IsOpen() || IsError() || IsEof() || IsAbort()) - return -1; - ReadBuffer(); - return ptr < end ? *ptr++ : -1; -} - -int TcpSocket::Peek_() -{ - if(!IsOpen() || IsError() || IsEof() || IsAbort()) - return -1; - ReadBuffer(); - return ptr < end ? *ptr : -1; -} - -int TcpSocket::Get(void *buffer, int count) -{ - LLOG("Get " << count); - - if(!IsOpen() || IsError() || IsEof() || IsAbort()) - return 0; - - String out; - int l = end - ptr; - done = 0; - if(l > 0) - if(l < count) { - memcpy(buffer, ptr, l); - done += l; - ptr = end; - } - else { - memcpy(buffer, ptr, count); - ptr += count; - return count; - } - while(done < count && !IsError() && !IsEof()) { - if(!WaitRead()) - break; - int part = Recv((char *)buffer + done, count - done); - if(part > 0) - done += part; - if(timeout == 0) - break; - } - return done; -} - -String TcpSocket::Get(int count) -{ - if(count == 0) - return Null; - StringBuffer out(count); - int done = Get(out, count); - if(!done && IsEof()) - return String::GetVoid(); - out.SetLength(done); - return out; -} - -String TcpSocket::GetLine(int maxlen) -{ - String ln; - for(;;) { - int c = Peek(); - if(c < 0) - return String::GetVoid(); - Get(); - if(c == '\n') - return ln; - if(c != '\r') - ln.Cat(c); - } -} - -void TcpSocket::SetSockError(const char *context, const char *errdesc) -{ - String err; - errorcode = GetErrorCode(); - if(socket != INVALID_SOCKET) - err << "socket(" << (int)socket << ") / "; - err << context << ": " << errdesc; - errordesc = err; - is_error = true; -} - -void TcpSocket::SetSockError(const char *context) -{ - SetSockError(context, TcpSocketErrorDesc(GetErrorCode())); -} - -TcpSocket::SSL *(*TcpSocket::CreateSSL)(); - -bool TcpSocket::StartSSL() -{ - if(!CreateSSL) { - errorcode = -1; - errordesc = "Missing SSL support (Core/SSL)"; - return false; - } - if(!IsOpen() || mode == NONE) { - errorcode = -1; - errordesc = "Socket not open or listening"; - return false; - } - ssl = (*CreateSSL)(); - if(!ssl->Start(*this)) { - ssl.Clear(); - return false; - } - return true; -} - -int SocketWaitEvent::Wait(int timeout) -{ - FD_ZERO(read); - FD_ZERO(write); - FD_ZERO(exception); - int maxindex = -1; - for(int i = 0; i < socket.GetCount(); i++) { - const Tuple2& s = socket[i]; - if(s.a >= 0) { - const Tuple2& s = socket[i]; - if(s.b & WAIT_READ) - FD_SET(s.a, read); - if(s.b & WAIT_WRITE) - FD_SET(s.a, write); - if(s.b & WAIT_EXCEPTION) - FD_SET(s.a, exception); - maxindex = max(s.a, maxindex); - } - } - timeval *tvalp = NULL; - timeval tval; - if(!IsNull(timeout)) { - tval.tv_sec = timeout / 1000; - tval.tv_usec = 1000 * (timeout % 1000); - tvalp = &tval; - } - return select(maxindex + 1, read, write, exception, tvalp); -} - -dword SocketWaitEvent::Get(int i) const -{ - int s = socket[i].a; - if(s < 0) - return 0; - dword events = 0; - if(FD_ISSET(s, read)) - events |= WAIT_READ; - if(FD_ISSET(s, write)) - events |= WAIT_WRITE; - if(FD_ISSET(s, exception)) - events |= WAIT_EXCEPTION; - return events; -} - -SocketWaitEvent::SocketWaitEvent() -{ - FD_ZERO(read); - FD_ZERO(write); - FD_ZERO(exception); -} - -END_UPP_NAMESPACE +#include "Core.h" + +#ifdef PLATFORM_WIN32 +#include +#include +#include +#endif + +#ifdef PLATFORM_POSIX +#include +#endif + +NAMESPACE_UPP + +#ifdef PLATFORM_WIN32 +#pragma comment(lib, "ws2_32.lib") +#endif + +#define LLOG(x) // DLOG("TCP " << x) + +IpAddrInfo::Entry IpAddrInfo::pool[COUNT]; + +RawMutex IpAddrInfoPoolMutex; + +void IpAddrInfo::EnterPool() +{ + IpAddrInfoPoolMutex.Enter(); +} + +void IpAddrInfo::LeavePool() +{ + IpAddrInfoPoolMutex.Leave(); +} + +int sGetAddrInfo(const char *host, const char *port, addrinfo **result) +{ + addrinfo hints; + memset(&hints, 0, sizeof(addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + return getaddrinfo(host, port, &hints, result); +} + +rawthread_t rawthread__ IpAddrInfo::Thread(void *ptr) +{ + Entry *entry = (Entry *)ptr; + EnterPool(); + if(entry->status == WORKING) { + char host[1025]; + char port[257]; + strcpy(host, entry->host); + strcpy(port, entry->port); + LeavePool(); + addrinfo *result; + if(sGetAddrInfo(host, port, &result) == 0 && result) { + EnterPool(); + if(entry->status == WORKING) { + entry->addr = result; + entry->status = RESOLVED; + } + else { + freeaddrinfo(entry->addr); + entry->status = EMPTY; + } + } + else { + EnterPool(); + if(entry->status == CANCELED) + entry->status = EMPTY; + else + entry->status = FAILED; + } + } + LeavePool(); + return 0; +} + +bool IpAddrInfo::Execute(const String& host, int port) +{ + Clear(); + entry = exe; + addrinfo *result; + entry->addr = sGetAddrInfo(~host, ~AsString(port), &result) == 0 ? result : NULL; + return entry->addr; +} + +void IpAddrInfo::Start() +{ + if(entry) + return; + EnterPool(); + for(int i = 0; i < COUNT; i++) { + Entry *e = pool + i; + if(e->status == EMPTY) { + entry = e; + e->addr = NULL; + if(host.GetCount() > 1024 || port.GetCount() > 256) + e->status = FAILED; + else { + e->status = WORKING; + e->host = host; + e->port = port; + StartRawThread(&IpAddrInfo::Thread, e); + } + break; + } + } + LeavePool(); +} + +void IpAddrInfo::Start(const String& host_, int port_) +{ + Clear(); + port = AsString(port_); + host = host_; + Start(); +} + +bool IpAddrInfo::InProgress() +{ + if(!entry) { + Start(); + return true; + } + EnterPool(); + int s = entry->status; + LeavePool(); + return s == WORKING; +} + +addrinfo *IpAddrInfo::GetResult() +{ + EnterPool(); + addrinfo *ai = entry ? entry->addr : NULL; + LeavePool(); + return ai; +} + +void IpAddrInfo::Clear() +{ + EnterPool(); + if(entry) { + if(entry->status == RESOLVED && entry->addr) + freeaddrinfo(entry->addr); + if(entry->status == WORKING) + entry->status = CANCELED; + entry->status = EMPTY; + entry = NULL; + } + LeavePool(); +} + +IpAddrInfo::IpAddrInfo() +{ + TcpSocket::Init(); + entry = NULL; +} + +#ifdef PLATFORM_POSIX + +#define SOCKERR(x) x + +const char *TcpSocketErrorDesc(int code) +{ + return strerror(code); +} + +int TcpSocket::GetErrorCode() +{ + return errno; +} + +#else + +#define SOCKERR(x) WSA##x + +const char *TcpSocketErrorDesc(int code) +{ + static Tuple2 err[] = { + { WSAEINTR, "Interrupted function call." }, + { WSAEACCES, "Permission denied." }, + { WSAEFAULT, "Bad address." }, + { WSAEINVAL, "Invalid argument." }, + { WSAEMFILE, "Too many open files." }, + { WSAEWOULDBLOCK, "Resource temporarily unavailable." }, + { WSAEINPROGRESS, "Operation now in progress." }, + { WSAEALREADY, "Operation already in progress." }, + { WSAENOTSOCK, "TcpSocket operation on nonsocket." }, + { WSAEDESTADDRREQ, "Destination address required." }, + { WSAEMSGSIZE, "Message too long." }, + { WSAEPROTOTYPE, "Protocol wrong type for socket." }, + { WSAENOPROTOOPT, "Bad protocol option." }, + { WSAEPROTONOSUPPORT, "Protocol not supported." }, + { WSAESOCKTNOSUPPORT, "TcpSocket type not supported." }, + { WSAEOPNOTSUPP, "Operation not supported." }, + { WSAEPFNOSUPPORT, "Protocol family not supported." }, + { WSAEAFNOSUPPORT, "Address family not supported by protocol family." }, + { WSAEADDRINUSE, "Address already in use." }, + { WSAEADDRNOTAVAIL, "Cannot assign requested address." }, + { WSAENETDOWN, "Network is down." }, + { WSAENETUNREACH, "Network is unreachable." }, + { WSAENETRESET, "Network dropped connection on reset." }, + { WSAECONNABORTED, "Software caused connection abort." }, + { WSAECONNRESET, "Connection reset by peer." }, + { WSAENOBUFS, "No buffer space available." }, + { WSAEISCONN, "TcpSocket is already connected." }, + { WSAENOTCONN, "TcpSocket is not connected." }, + { WSAESHUTDOWN, "Cannot send after socket shutdown." }, + { WSAETIMEDOUT, "Connection timed out." }, + { WSAECONNREFUSED, "Connection refused." }, + { WSAEHOSTDOWN, "Host is down." }, + { WSAEHOSTUNREACH, "No route to host." }, + { WSAEPROCLIM, "Too many processes." }, + { WSASYSNOTREADY, "Network subsystem is unavailable." }, + { WSAVERNOTSUPPORTED, "Winsock.dll version out of range." }, + { WSANOTINITIALISED, "Successful WSAStartup not yet performed." }, + { WSAEDISCON, "Graceful shutdown in progress." }, + { WSATYPE_NOT_FOUND, "Class type not found." }, + { WSAHOST_NOT_FOUND, "Host not found." }, + { WSATRY_AGAIN, "Nonauthoritative host not found." }, + { WSANO_RECOVERY, "This is a nonrecoverable error." }, + { WSANO_DATA, "Valid name, no data record of requested type." }, + { WSASYSCALLFAILURE, "System call failure." }, + }; + const Tuple2 *x = FindTuple(err, __countof(err), code); + return x ? x->b : "Unknown error code."; +} + +int TcpSocket::GetErrorCode() +{ + return WSAGetLastError(); +} + +#endif + +void TcpSocketInit() +{ +#if defined(PLATFORM_WIN32) + ONCELOCK { + WSADATA wsadata; + WSAStartup(MAKEWORD(2, 2), &wsadata); + } +#endif +} + +void TcpSocket::Init() +{ + TcpSocketInit(); +} + +void TcpSocket::Reset() +{ + is_eof = false; + socket = INVALID_SOCKET; + ipv6 = false; + ptr = end = buffer; + is_error = false; + is_abort = false; + mode = NONE; + ssl.Clear(); +} + +TcpSocket::TcpSocket() +{ + ClearError(); + Reset(); + timeout = Null; + waitstep = 20; +} + +bool TcpSocket::Open(int family, int type, int protocol) +{ + Init(); + Close(); + ClearError(); + if((socket = ::socket(family, type, protocol)) == INVALID_SOCKET) + return false; + LLOG("TcpSocket::Data::Open() -> " << (int)socket); +#ifdef PLATFORM_WIN32 + u_long arg = 1; + if(ioctlsocket(socket, FIONBIO, &arg)) + SetSockError("ioctlsocket(FIO[N]BIO)"); +#else + if(fcntl(socket, F_SETFL, (fcntl(socket, F_GETFL, 0) | O_NONBLOCK))) + SetSockError("fcntl(O_[NON]BLOCK)"); +#endif + return true; +} + +bool TcpSocket::Listen(int port, int listen_count, bool ipv6_, bool reuse) +{ + Init(); + + ipv6 = ipv6_; + if(!Open(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) + return false; + sockaddr_in sin; +#ifdef PLATFORM_WIN32 + SOCKADDR_IN6 sin6; + if(ipv6 && IsWinVista()) +#else + sockaddr_in6 sin6; + if(ipv6) +#endif + { + Zero(sin6); + sin.sin_family = AF_INET6; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = htonl(INADDR_ANY); + } + else { + Zero(sin); + sin.sin_family = AF_INET; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = htonl(INADDR_ANY); + } + if(reuse) { + int optval = 1; + setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (const char *)&optval, sizeof(optval)); + } + if(bind(socket, ipv6 ? (const sockaddr *)&sin6 : (const sockaddr *)&sin, + ipv6 ? sizeof(sin6) : sizeof(sin))) { + SetSockError(Format("bind(port=%d)", port)); + return false; + } + if(listen(socket, listen_count)) { + SetSockError(Format("listen(port=%d, count=%d)", port, listen_count)); + return false; + } + return true; +} + +bool TcpSocket::Accept(TcpSocket& ls) +{ + Close(); + if(timeout && !ls.WaitRead()) + return false; + if(!Open(ls.ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, 0)) + return false; + socket = accept(ls.GetSOCKET(), NULL, NULL); + if(socket == INVALID_SOCKET) { + SetSockError("accept"); + return false; + } + mode = ACCEPT; + return true; +} + +String TcpSocket::GetPeerAddr() const +{ + if(!IsOpen()) + return Null; + sockaddr_in addr; + socklen_t l = sizeof(addr); + if(getpeername(socket, (sockaddr *)&addr, &l) != 0) + return Null; + if(l > sizeof(addr)) + return Null; +#ifdef PLATFORM_WIN32 + return inet_ntoa(addr.sin_addr); +#else + char h[200]; + return inet_ntop(AF_INET, &addr.sin_addr, h, 200); +#endif +} + +void TcpSocket::NoDelay() +{ + ASSERT(IsOpen()); + int __true = 1; + LLOG("NoDelay(" << (int)socket << ")"); + if(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (const char *)&__true, sizeof(__true))) + SetSockError("setsockopt(TCP_NODELAY)"); +} + +void TcpSocket::Linger(int msecs) +{ + ASSERT(IsOpen()); + linger ls; + ls.l_onoff = !IsNull(msecs) ? 1 : 0; + ls.l_linger = !IsNull(msecs) ? (msecs + 999) / 1000 : 0; + if(setsockopt(socket, SOL_SOCKET, SO_LINGER, + reinterpret_cast(&ls), sizeof(ls))) + SetSockError("setsockopt(SO_LINGER)"); +} + +void TcpSocket::Attach(SOCKET s) +{ + Close(); + socket = s; +} + +bool TcpSocket::RawConnect(addrinfo *rp) +{ + for(;;) { + if(Open(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) { + if(connect(socket, rp->ai_addr, rp->ai_addrlen) == 0 || + GetErrorCode() == SOCKERR(EINPROGRESS) || GetErrorCode() == SOCKERR(EWOULDBLOCK)) + break; + Close(); + } + rp = rp->ai_next; + if(!rp) { + SetSockError("connect has failed"); + return false; + } + } + mode = CONNECT; + return true; +} + + +bool TcpSocket::Connect(IpAddrInfo& info) +{ + LLOG("TCP Connect addrinfo"); + Init(); + addrinfo *result = info.GetResult(); + return result && RawConnect(result); +} + +bool TcpSocket::Connect(const char *host, int port) +{ + LLOG("TCP Connect(" << host << ':' << port << ')'); + + Init(); + IpAddrInfo info; + if(!info.Execute(host, port)) { + SetSockError(Format("getaddrinfo(%s) failed", host)); + return false; + } + return Connect(info); +} + +void TcpSocket::RawClose() +{ + LLOG("TCP close " << (int)socket); + if(socket != INVALID_SOCKET) { + int res; +#if defined(PLATFORM_WIN32) + res = closesocket(socket); +#elif defined(PLATFORM_POSIX) + res = close(socket); +#else + #error Unsupported platform +#endif + if(res && !IsError()) + SetSockError("close"); + socket = INVALID_SOCKET; + } +} + +void TcpSocket::Close() +{ + if(ssl) + ssl->Close(); + else + RawClose(); + ssl.Clear(); +} + +bool TcpSocket::WouldBlock() +{ + int c = GetErrorCode(); +#ifdef PLATFORM_POSIX + return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(EAGAIN); +#endif +#ifdef PLATFORM_WIN32 + return c == SOCKERR(EWOULDBLOCK) || c == SOCKERR(ENOTCONN); +#endif +} + +int TcpSocket::RawSend(const void *buf, int amount) +{ + int res = send(socket, (const char *)buf, amount, 0); + if(res < 0 && WouldBlock()) + res = 0; + else + if(res == 0 || res < 0) + SetSockError("send"); + return res; +} + +int TcpSocket::Send(const void *buf, int amount) +{ + return ssl ? ssl->Send(buf, amount) : RawSend(buf, amount); +} + +void TcpSocket::Shutdown() +{ + ASSERT(IsOpen()); + if(shutdown(socket, SD_SEND)) + SetSockError("shutdown(SD_SEND)"); +} + +String TcpSocket::GetHostName() +{ + Init(); + char buffer[256]; + gethostname(buffer, __countof(buffer)); + return buffer; +} + +bool TcpSocket::RawWait(dword flags) +{ + LLOG("Wait(" << timeout << ", " << flags << ")"); + if((flags & WAIT_READ) && ptr != end) + return true; + int end_time = msecs() + timeout; + if(socket == INVALID_SOCKET) + return false; + for(;;) { + if(IsError() || IsAbort()) + return false; + int to = end_time - msecs(); + if(WhenWait) + to = waitstep; + timeval *tvalp = NULL; + timeval tval; + if(!IsNull(timeout) || WhenWait) { + to = max(to, 0); + tval.tv_sec = to / 1000; + tval.tv_usec = 1000 * (to % 1000); + tvalp = &tval; + } + fd_set fdset[1]; + FD_ZERO(fdset); + FD_SET(socket, fdset); + int avail = select((int)socket + 1, + flags & WAIT_READ ? fdset : NULL, + flags & WAIT_WRITE ? fdset : NULL, + flags & WAIT_EXCEPTION ? fdset : NULL, tvalp); + LLOG("Wait select avail: " << avail); + if(avail < 0) { + SetSockError("wait"); + return false; + } + if(avail > 0) + return true; + if(to <= 0 && timeout) { + return false; + } + WhenWait(); + if(timeout == 0) + return false; + } +} + +bool TcpSocket::Wait(dword flags) +{ + return ssl ? ssl->Wait(flags) : RawWait(flags); +} + +int TcpSocket::Put(const char *s, int length) +{ + LLOG("Put " << socket << ": " << length); + ASSERT(IsOpen()); + if(length < 0 && s) + length = (int)strlen(s); + if(!s || length <= 0 || IsError() || IsAbort()) + return 0; + done = 0; + bool peek = false; + while(done < length) { + if(peek && !WaitWrite()) + return done; + peek = false; + int count = Send(s + done, length - done); + if(IsError() || timeout == 0 && count == 0 && peek) + return done; + if(count > 0) + done += count; + else + peek = true; + } + LLOG("//Put() -> " << done); + return done; +} + +int TcpSocket::RawRecv(void *buf, int amount) +{ + int res = recv(socket, (char *)buf, amount, 0); + if(res == 0) + is_eof = true; + else + if(res < 0 && WouldBlock()) + res = 0; + else + if(res < 0) + SetSockError("recv"); + LLOG("recv(" << socket << "): " << res << " bytes: " + << AsCString((char *)buf, (char *)buf + min(res, 16)) + << (res ? "" : IsEof() ? ", EOF" : ", WOULDBLOCK")); + return res; +} + +int TcpSocket::Recv(void *buffer, int maxlen) +{ + return ssl ? ssl->Recv(buffer, maxlen) : RawRecv(buffer, maxlen); +} + +void TcpSocket::ReadBuffer() +{ + ptr = end = buffer; + if(WaitRead()) + end = buffer + Recv(buffer, BUFFERSIZE); +} + +int TcpSocket::Get_() +{ + if(!IsOpen() || IsError() || IsEof() || IsAbort()) + return -1; + ReadBuffer(); + return ptr < end ? *ptr++ : -1; +} + +int TcpSocket::Peek_() +{ + if(!IsOpen() || IsError() || IsEof() || IsAbort()) + return -1; + ReadBuffer(); + return ptr < end ? *ptr : -1; +} + +int TcpSocket::Get(void *buffer, int count) +{ + LLOG("Get " << count); + + if(!IsOpen() || IsError() || IsEof() || IsAbort()) + return 0; + + String out; + int l = end - ptr; + done = 0; + if(l > 0) + if(l < count) { + memcpy(buffer, ptr, l); + done += l; + ptr = end; + } + else { + memcpy(buffer, ptr, count); + ptr += count; + return count; + } + while(done < count && !IsError() && !IsEof()) { + if(!WaitRead()) + break; + int part = Recv((char *)buffer + done, count - done); + if(part > 0) + done += part; + if(timeout == 0) + break; + } + return done; +} + +String TcpSocket::Get(int count) +{ + if(count == 0) + return Null; + StringBuffer out(count); + int done = Get(out, count); + if(!done && IsEof()) + return String::GetVoid(); + out.SetLength(done); + return out; +} + +String TcpSocket::GetLine(int maxlen) +{ + String ln; + for(;;) { + int c = Peek(); + if(c < 0) + return String::GetVoid(); + Get(); + if(c == '\n') + return ln; + if(c != '\r') + ln.Cat(c); + } +} + +void TcpSocket::SetSockError(const char *context, int code, const char *errdesc) +{ + errorcode = code; + errordesc.Clear(); + if(socket != INVALID_SOCKET) + errordesc << "socket(" << (int)socket << ") / "; + errordesc << context << ": " << errdesc; + is_error = true; +} + +void TcpSocket::SetSockError(const char *context, const char *errdesc) +{ + SetSockError(context, GetErrorCode(), errdesc); +} + +void TcpSocket::SetSockError(const char *context) +{ + SetSockError(context, TcpSocketErrorDesc(GetErrorCode())); +} + +TcpSocket::SSL *(*TcpSocket::CreateSSL)(TcpSocket& socket); + +bool TcpSocket::StartSSL() +{ + ASSERT(IsOpen()); + if(!CreateSSL) { + errorcode = -1; + errordesc = "Missing SSL support (Core/SSL)"; + return false; + } + if(!IsOpen() || mode == NONE) { + errorcode = -1; + errordesc = "Socket not open or listening"; + return false; + } + ssl = (*CreateSSL)(*this); + if(!ssl->Start()) { + ssl.Clear(); + return false; + } + return true; +} + +int SocketWaitEvent::Wait(int timeout) +{ + FD_ZERO(read); + FD_ZERO(write); + FD_ZERO(exception); + int maxindex = -1; + for(int i = 0; i < socket.GetCount(); i++) { + const Tuple2& s = socket[i]; + if(s.a >= 0) { + const Tuple2& s = socket[i]; + if(s.b & WAIT_READ) + FD_SET(s.a, read); + if(s.b & WAIT_WRITE) + FD_SET(s.a, write); + if(s.b & WAIT_EXCEPTION) + FD_SET(s.a, exception); + maxindex = max(s.a, maxindex); + } + } + timeval *tvalp = NULL; + timeval tval; + if(!IsNull(timeout)) { + tval.tv_sec = timeout / 1000; + tval.tv_usec = 1000 * (timeout % 1000); + tvalp = &tval; + } + return select(maxindex + 1, read, write, exception, tvalp); +} + +dword SocketWaitEvent::Get(int i) const +{ + int s = socket[i].a; + if(s < 0) + return 0; + dword events = 0; + if(FD_ISSET(s, read)) + events |= WAIT_READ; + if(FD_ISSET(s, write)) + events |= WAIT_WRITE; + if(FD_ISSET(s, exception)) + events |= WAIT_EXCEPTION; + return events; +} + +SocketWaitEvent::SocketWaitEvent() +{ + FD_ZERO(read); + FD_ZERO(write); + FD_ZERO(exception); +} + +END_UPP_NAMESPACE diff --git a/uppsrc/Core/Stream.cpp b/uppsrc/Core/Stream.cpp index 6f08db9eb..b49decdcd 100644 --- a/uppsrc/Core/Stream.cpp +++ b/uppsrc/Core/Stream.cpp @@ -74,7 +74,7 @@ void Stream::LoadError() { throw LoadingError(); } -bool Stream::GetAll(void *data, dword size) { +bool Stream::GetAll(void *data, int size) { if(Get(data, size) != size) { LoadError(); return false; @@ -1084,7 +1084,7 @@ void CompareStream::Seek(int64 apos) { ptr = buffer; } -void CompareStream::Compare(int64 pos, const void *data, dword size) { +void CompareStream::Compare(int64 pos, const void *data, int size) { ASSERT(stream); if(!size) return; Buffer b(size); diff --git a/uppsrc/Core/Stream.h b/uppsrc/Core/Stream.h index d747e7ee0..05dc0d577 100644 --- a/uppsrc/Core/Stream.h +++ b/uppsrc/Core/Stream.h @@ -95,7 +95,7 @@ public: void LoadThrowing() { style |= STRM_THROW; } void LoadError(); - bool GetAll(void *data, dword size); + bool GetAll(void *data, int size); int Get8() { return ptr < rdlim ? *ptr++ : _Get8(); } #ifdef CPU_X86 @@ -487,7 +487,7 @@ private: int64 size; byte h[128]; - void Compare(int64 pos, const void *data, dword size); + void Compare(int64 pos, const void *data, int size); public: void Open(Stream& aStream); diff --git a/uppsrc/Core/Web.h b/uppsrc/Core/Web.h index d84d4947a..be31f6c17 100644 --- a/uppsrc/Core/Web.h +++ b/uppsrc/Core/Web.h @@ -1,367 +1,372 @@ -String FormatIP(dword _ip); - -String UrlEncode(const String& s); -String UrlEncode(const String& s, const char *specials); -String UrlDecode(const char *b, const char *e); -inline String UrlDecode(const String& s) { return UrlDecode(s.Begin(), s.End() ); } - -String Base64Encode(const char *b, const char *e); -inline String Base64Encode(const String& data) { return Base64Encode(data.Begin(), data.End()); } -String Base64Decode(const char *b, const char *e); -inline String Base64Decode(const String& data) { return Base64Decode(data.Begin(), data.End()); } - -class IpAddrInfo { - enum { COUNT = 32 }; - struct Entry { - const char *host; - const char *port; - int status; - addrinfo *addr; - }; - static Entry pool[COUNT]; - - enum { - EMPTY = 0, WORKING, CANCELED, RESOLVED, FAILED - }; - - String host, port; - Entry *entry; - Entry exe[1]; - - static void EnterPool(); - static void LeavePool(); - static rawthread_t rawthread__ Thread(void *ptr); - - void Start(); - -public: - void Start(const String& host, int port); - bool InProgress(); - bool Execute(const String& host, int port); - addrinfo *GetResult(); - void Clear(); - - IpAddrInfo(); - ~IpAddrInfo() { Clear(); } -}; - -enum { WAIT_READ = 1, WAIT_WRITE = 2, WAIT_EXCEPTION = 4, WAIT_ALL = 7 }; - -class TcpSocket { - enum { BUFFERSIZE = 512 }; - enum { NONE, CONNECT, ACCEPT }; - SOCKET socket; - int mode; - char buffer[BUFFERSIZE]; - char *ptr; - char *end; - bool is_eof; - bool is_error; - bool is_abort; - bool ipv6; - - int timeout; - int waitstep; - int done; - - int errorcode; - String errordesc; - - struct SSL { - virtual bool Start(TcpSocket& s) = 0; - virtual bool Wait(TcpSocket& s, dword flags) = 0; - virtual int Send(TcpSocket& s, const void *buffer, int maxlen) = 0; - virtual int Recv(TcpSocket& s, void *buffer, int maxlen) = 0; - virtual void Close(TcpSocket& s) = 0; - }; - - struct SSLImp; - - friend struct SSLImp; - - One ssl; - - static SSL *(*CreateSSL)(); - - SSLImp *CreateSSLImp(); - friend void InitCreateSSL(); - - bool RawWait(dword flags); - SOCKET AcceptRaw(dword *ipaddr, int timeout_msec); - bool Open(int family, int type, int protocol); - int RawRecv(void *buffer, int maxlen); - int Recv(void *buffer, int maxlen); - int RawSend(const void *buffer, int maxlen); - int Send(const void *buffer, int maxlen); - bool RawConnect(addrinfo *info); - void RawClose(); - - void ReadBuffer(); - int Get_(); - int Peek_(); - - void Reset(); - - void SetSockError(const char *context, const char *errdesc); - void SetSockError(const char *context); - - static int GetErrorCode(); - static bool WouldBlock(); - -public: - Callback WhenWait; - - static String GetHostName(); - - int GetDone() const { return done; } - - static void Init(); - - bool IsOpen() const { return socket != INVALID_SOCKET; } - bool IsEof() const { return is_eof && ptr == end; } - - bool IsError() const { return is_error; } - void ClearError() { is_error = false; errorcode = 0; errordesc.Clear(); } - int GetError() const { return errorcode; } - String GetErrorDesc() const { return errordesc; } - - void Abort() { is_abort = true; } - bool IsAbort() const { return is_abort; } - void ClearAbort() { is_abort = false; } - - SOCKET GetSOCKET() const { return socket; } - String GetPeerAddr() const; - - void Attach(SOCKET socket); - bool Connect(const char *host, int port); - bool Connect(IpAddrInfo& info); - bool Listen(int port, int listen_count, bool ipv6 = false, bool reuse = true); - bool Accept(TcpSocket& listen_socket); - void Close(); - void Shutdown(); - - void NoDelay(); - void Linger(int msecs); - void NoLinger() { Linger(Null); } - void Reuse(bool reuse = true); - - bool Wait(dword events); - bool WaitRead() { return Wait(WAIT_READ); } - bool WaitWrite() { return Wait(WAIT_WRITE); } - - int Peek() { return ptr < end ? *ptr : Peek_(); } - int Term() { return Peek(); } - int Get() { return ptr < end ? *ptr++ : Get_(); } - int Get(void *buffer, int len); - String Get(int len); - int GetAll(void *buffer, int len) { return Get(buffer, len) == len; } - String GetAll(int len) { String s = Get(len); return s.GetCount() == len ? s : String::GetVoid(); } - String GetLine(int maxlen = 2000000); - - int Put(const char *s, int len); - int Put(const String& s) { return Put(s.Begin(), s.GetLength()); } - bool PutAll(const char *s, int len) { return Put(s, len) == len; } - bool PutAll(const String& s) { return Put(s) == s.GetCount(); } - - bool StartSSL(); - - TcpSocket& Timeout(int ms) { timeout = ms; return *this; } - int GetTimeout() const { return timeout; } - TcpSocket& Blocking() { return Timeout(Null); } - - TcpSocket(); - ~TcpSocket() { Close(); } -}; - -class SocketWaitEvent { - Vector< Tuple2 > socket; - fd_set read[1], write[1], exception[1]; - -public: - void Clear() { socket.Clear(); } - void Add(SOCKET s, dword events = WAIT_ALL) { socket.Add(MakeTuple((int)s, events)); } - void Add(TcpSocket& s, dword events = WAIT_ALL) { Add(s.GetSOCKET(), events); } - int Wait(int timeout); - dword Get(int i) const; - dword operator[](int i) const { return Get(i); } - - SocketWaitEvent(); -}; - -struct HttpHeader { - String first_line; - VectorMap fields; - - String operator[](const char *id) { return fields.Get(id, Null); } - - bool Response(String& protocol, int& code, String& reason); - bool Request(String& method, String& uri, String& version); - - void Clear(); - bool Parse(const String& hdrs); -}; - -class HttpRequest : public TcpSocket { - int phase; - String data; - int count; - - HttpHeader header; - - String error; - String body; - - enum { - DEFAULT_HTTP_PORT = 80, - }; - - enum { - METHOD_GET, - METHOD_POST, - METHOD_HEAD, - METHOD_PUT, - }; - - int max_header_size; - int max_content_size; - int max_redirects; - int max_retries; - int timeout; - - String host; - int port; - String proxy_host; - int proxy_port; - String proxy_username; - String proxy_password; - String path; - - int method; - String accept; - String agent; - bool force_digest; - bool is_post; - bool std_headers; - bool hasurlvar; - String contenttype; - String username; - String password; - String digest; - String request_headers; - String postdata; - - String protocol; - int status_code; - String reason_phrase; - - int start_time; - int retry_count; - int redirect_count; - - int chunk; - - IpAddrInfo addrinfo; - int bodylen; - bool gzip; - Zlib z; - - void Init(); - - void StartPhase(int s); - void Start(); - void Dns(); - void StartRequest(); - bool SendingData(); - bool ReadingHeader(); - void StartBody(); - bool ReadingBody(); - void ReadingChunkHeader(); - void Finish(); - - void HttpError(const char *s); - void ContentOut(const void *ptr, dword size); - void Out(const void *ptr, dword size); - - String CalculateDigest(const String& authenticate) const; - -public: - Callback2 WhenContent; - - HttpRequest& MaxHeaderSize(int m) { max_header_size = m; return *this; } - HttpRequest& MaxContentSize(int m) { max_content_size = m; return *this; } - HttpRequest& MaxRedirect(int n) { max_redirects = n; return *this; } - HttpRequest& MaxRetries(int n) { max_retries = n; return *this; } - HttpRequest& RequestTimeout(int ms) { timeout = ms; return *this; } - HttpRequest& ChunkSize(int n) { chunk = n; return *this; } - - HttpRequest& Method(int m) { method = m; return *this; } - HttpRequest& GET() { return Method(METHOD_GET); } - HttpRequest& POST() { return Method(METHOD_POST); } - HttpRequest& HEAD() { return Method(METHOD_HEAD); } - HttpRequest& PUT() { return Method(METHOD_PUT); } - - HttpRequest& Host(const String& h) { host = h; return *this; } - HttpRequest& Port(int p) { port = p; return *this; } - HttpRequest& Path(const String& p) { path = p; return *this; } - HttpRequest& User(const String& u, const String& p) { username = u; password = p; return *this; } - HttpRequest& Digest() { force_digest = true; return *this; } - HttpRequest& Digest(const String& d) { digest = d; return *this; } - HttpRequest& Url(const char *url); - HttpRequest& UrlVar(const char *id, const String& data); - HttpRequest& operator()(const char *id, const String& data) { return UrlVar(id, data); } - HttpRequest& PostData(const String& pd) { postdata = pd; return *this; } - HttpRequest& PostUData(const String& pd) { return PostData(UrlEncode(pd)); } - HttpRequest& Post(const String& data) { POST(); return PostData(data); } - HttpRequest& Post(const char *id, const String& data); - - HttpRequest& Headers(const String& h) { request_headers = h; return *this; } - HttpRequest& ClearHeaders() { return Headers(Null); } - HttpRequest& AddHeaders(const String& h) { request_headers.Cat(h); return *this; } - HttpRequest& Header(const char *id, const String& data); - - HttpRequest& StdHeaders(bool sh) { std_headers = sh; return *this; } - HttpRequest& NoStdHeaders() { return StdHeaders(false); } - HttpRequest& Accept(const String& a) { accept = a; return *this; } - HttpRequest& Agent(const String& a) { agent = a; return *this; } - HttpRequest& ContentType(const String& a) { contenttype = a; return *this; } - - HttpRequest& Proxy(const String& host, int port) { proxy_host = host; proxy_port = port; return *this; } - HttpRequest& Proxy(const char *url); - HttpRequest& ProxyAuth(const String& u, const String& p) { proxy_username = u; proxy_password = p; return *this; } - - bool IsSocketError() const { return TcpSocket::IsError(); } - bool IsHttpError() const { return !IsNull(error) ; } - bool IsError() const { return IsSocketError() || IsHttpError(); } - String GetErrorDesc() const { return IsSocketError() ? TcpSocket::GetErrorDesc() : error; } - void ClearError() { TcpSocket::ClearError(); error.Clear(); } - - String GetHeader(const char *s) { return header[s]; } - String operator[](const char *s) { return GetHeader(s); } - String GetRedirectUrl(); - int GetContentLength(); - int GetStatusCode() const { return status_code; } - String GetReasonPhrase() const { return reason_phrase; } - - String GetContent() const { return body; } - String operator~() const { return GetContent(); } - operator String() const { return GetContent(); } - void ClearContent() { body.Clear(); } - - enum Phase { - START, DNS, REQUEST, HEADER, BODY, CHUNK_HEADER, CHUNK_BODY, TRAILER, FINISHED, FAILED - }; - - bool Do(); - int GetPhase() const { return phase; } - String GetPhaseName() const; - bool InProgress() const { return phase != FAILED && phase != FINISHED; } - bool IsFailure() const { return phase == FAILED; } - bool IsSuccess() const { return phase == FINISHED && status_code >= 200 && status_code < 300; } - - String Execute(); - - HttpRequest(); - HttpRequest(const char *url); - - static void Trace(bool b = true); -}; +String FormatIP(dword _ip); + +String UrlEncode(const String& s); +String UrlEncode(const String& s, const char *specials); +String UrlDecode(const char *b, const char *e); +inline String UrlDecode(const String& s) { return UrlDecode(s.Begin(), s.End() ); } + +String Base64Encode(const char *b, const char *e); +inline String Base64Encode(const String& data) { return Base64Encode(data.Begin(), data.End()); } +String Base64Decode(const char *b, const char *e); +inline String Base64Decode(const String& data) { return Base64Decode(data.Begin(), data.End()); } + +class IpAddrInfo { + enum { COUNT = 32 }; + struct Entry { + const char *host; + const char *port; + int status; + addrinfo *addr; + }; + static Entry pool[COUNT]; + + enum { + EMPTY = 0, WORKING, CANCELED, RESOLVED, FAILED + }; + + String host, port; + Entry *entry; + Entry exe[1]; + + static void EnterPool(); + static void LeavePool(); + static rawthread_t rawthread__ Thread(void *ptr); + + void Start(); + +public: + void Start(const String& host, int port); + bool InProgress(); + bool Execute(const String& host, int port); + addrinfo *GetResult(); + void Clear(); + + IpAddrInfo(); + ~IpAddrInfo() { Clear(); } +}; + +enum { WAIT_READ = 1, WAIT_WRITE = 2, WAIT_EXCEPTION = 4, WAIT_ALL = 7 }; + +class TcpSocket { + enum { BUFFERSIZE = 512 }; + enum { NONE, CONNECT, ACCEPT }; + SOCKET socket; + int mode; + char buffer[BUFFERSIZE]; + char *ptr; + char *end; + bool is_eof; + bool is_error; + bool is_abort; + bool ipv6; + + int timeout; + int waitstep; + int done; + + int errorcode; + String errordesc; + + struct SSL { + virtual bool Start() = 0; + virtual bool Wait(dword flags) = 0; + virtual int Send(const void *buffer, int maxlen) = 0; + virtual int Recv(void *buffer, int maxlen) = 0; + virtual void Close() = 0; + + virtual ~SSL() {} + }; + + One ssl; + + struct SSLImp; + friend struct SSLImp; + + static SSL *(*CreateSSL)(TcpSocket& socket); + static SSL *CreateSSLImp(TcpSocket& socket); + + friend void InitCreateSSL(); + + bool RawWait(dword flags); + SOCKET AcceptRaw(dword *ipaddr, int timeout_msec); + bool Open(int family, int type, int protocol); + int RawRecv(void *buffer, int maxlen); + int Recv(void *buffer, int maxlen); + int RawSend(const void *buffer, int maxlen); + int Send(const void *buffer, int maxlen); + bool RawConnect(addrinfo *info); + void RawClose(); + + void ReadBuffer(); + int Get_(); + int Peek_(); + + void Reset(); + + void SetSockError(const char *context, int code, const char *errdesc); + void SetSockError(const char *context, const char *errdesc); + void SetSockError(const char *context); + + static int GetErrorCode(); + static bool WouldBlock(); + +public: + Callback WhenWait; + + static String GetHostName(); + + int GetDone() const { return done; } + + static void Init(); + + bool IsOpen() const { return socket != INVALID_SOCKET; } + bool IsEof() const { return is_eof && ptr == end; } + + bool IsError() const { return is_error; } + void ClearError() { is_error = false; errorcode = 0; errordesc.Clear(); } + int GetError() const { return errorcode; } + String GetErrorDesc() const { return errordesc; } + + void Abort() { is_abort = true; } + bool IsAbort() const { return is_abort; } + void ClearAbort() { is_abort = false; } + + SOCKET GetSOCKET() const { return socket; } + String GetPeerAddr() const; + + void Attach(SOCKET socket); + bool Connect(const char *host, int port); + bool Connect(IpAddrInfo& info); + bool Listen(int port, int listen_count, bool ipv6 = false, bool reuse = true); + bool Accept(TcpSocket& listen_socket); + void Close(); + void Shutdown(); + + void NoDelay(); + void Linger(int msecs); + void NoLinger() { Linger(Null); } + void Reuse(bool reuse = true); + + bool Wait(dword events); + bool WaitRead() { return Wait(WAIT_READ); } + bool WaitWrite() { return Wait(WAIT_WRITE); } + + int Peek() { return ptr < end ? *ptr : Peek_(); } + int Term() { return Peek(); } + int Get() { return ptr < end ? *ptr++ : Get_(); } + int Get(void *buffer, int len); + String Get(int len); + int GetAll(void *buffer, int len) { return Get(buffer, len) == len; } + String GetAll(int len) { String s = Get(len); return s.GetCount() == len ? s : String::GetVoid(); } + String GetLine(int maxlen = 2000000); + + int Put(const char *s, int len); + int Put(const String& s) { return Put(s.Begin(), s.GetLength()); } + bool PutAll(const char *s, int len) { return Put(s, len) == len; } + bool PutAll(const String& s) { return Put(s) == s.GetCount(); } + + bool StartSSL(); + bool IsSSL() const { return ssl; } + + TcpSocket& Timeout(int ms) { timeout = ms; return *this; } + int GetTimeout() const { return timeout; } + TcpSocket& Blocking() { return Timeout(Null); } + + TcpSocket(); + ~TcpSocket() { Close(); } +}; + +class SocketWaitEvent { + Vector< Tuple2 > socket; + fd_set read[1], write[1], exception[1]; + +public: + void Clear() { socket.Clear(); } + void Add(SOCKET s, dword events = WAIT_ALL) { socket.Add(MakeTuple((int)s, events)); } + void Add(TcpSocket& s, dword events = WAIT_ALL) { Add(s.GetSOCKET(), events); } + int Wait(int timeout); + dword Get(int i) const; + dword operator[](int i) const { return Get(i); } + + SocketWaitEvent(); +}; + +struct HttpHeader { + String first_line; + VectorMap fields; + + String operator[](const char *id) { return fields.Get(id, Null); } + + bool Response(String& protocol, int& code, String& reason); + bool Request(String& method, String& uri, String& version); + + void Clear(); + bool Parse(const String& hdrs); +}; + +class HttpRequest : public TcpSocket { + int phase; + String data; + int count; + + HttpHeader header; + + String error; + String body; + + enum { + DEFAULT_HTTP_PORT = 80, + }; + + enum { + METHOD_GET, + METHOD_POST, + METHOD_HEAD, + METHOD_PUT, + }; + + int max_header_size; + int max_content_size; + int max_redirects; + int max_retries; + int timeout; + + String host; + int port; + String proxy_host; + int proxy_port; + String proxy_username; + String proxy_password; + String path; + bool ssl; + + int method; + String accept; + String agent; + bool force_digest; + bool is_post; + bool std_headers; + bool hasurlvar; + String contenttype; + String username; + String password; + String digest; + String request_headers; + String postdata; + + String protocol; + int status_code; + String reason_phrase; + + int start_time; + int retry_count; + int redirect_count; + + int chunk; + + IpAddrInfo addrinfo; + int bodylen; + bool gzip; + Zlib z; + + void Init(); + + void StartPhase(int s); + void Start(); + void Dns(); + void StartRequest(); + bool SendingData(); + bool ReadingHeader(); + void StartBody(); + bool ReadingBody(); + void ReadingChunkHeader(); + void Finish(); + + void HttpError(const char *s); + void ContentOut(const void *ptr, dword size); + void Out(const void *ptr, dword size); + + String CalculateDigest(const String& authenticate) const; + +public: + Callback2 WhenContent; + + HttpRequest& MaxHeaderSize(int m) { max_header_size = m; return *this; } + HttpRequest& MaxContentSize(int m) { max_content_size = m; return *this; } + HttpRequest& MaxRedirect(int n) { max_redirects = n; return *this; } + HttpRequest& MaxRetries(int n) { max_retries = n; return *this; } + HttpRequest& RequestTimeout(int ms) { timeout = ms; return *this; } + HttpRequest& ChunkSize(int n) { chunk = n; return *this; } + + HttpRequest& Method(int m) { method = m; return *this; } + HttpRequest& GET() { return Method(METHOD_GET); } + HttpRequest& POST() { return Method(METHOD_POST); } + HttpRequest& HEAD() { return Method(METHOD_HEAD); } + HttpRequest& PUT() { return Method(METHOD_PUT); } + + HttpRequest& Host(const String& h) { host = h; return *this; } + HttpRequest& Port(int p) { port = p; return *this; } + HttpRequest& SSL(bool b = true) { ssl = b; return *this; } + HttpRequest& Path(const String& p) { path = p; return *this; } + HttpRequest& User(const String& u, const String& p) { username = u; password = p; return *this; } + HttpRequest& Digest() { force_digest = true; return *this; } + HttpRequest& Digest(const String& d) { digest = d; return *this; } + HttpRequest& Url(const char *url); + HttpRequest& UrlVar(const char *id, const String& data); + HttpRequest& operator()(const char *id, const String& data) { return UrlVar(id, data); } + HttpRequest& PostData(const String& pd) { postdata = pd; return *this; } + HttpRequest& PostUData(const String& pd) { return PostData(UrlEncode(pd)); } + HttpRequest& Post(const String& data) { POST(); return PostData(data); } + HttpRequest& Post(const char *id, const String& data); + + HttpRequest& Headers(const String& h) { request_headers = h; return *this; } + HttpRequest& ClearHeaders() { return Headers(Null); } + HttpRequest& AddHeaders(const String& h) { request_headers.Cat(h); return *this; } + HttpRequest& Header(const char *id, const String& data); + + HttpRequest& StdHeaders(bool sh) { std_headers = sh; return *this; } + HttpRequest& NoStdHeaders() { return StdHeaders(false); } + HttpRequest& Accept(const String& a) { accept = a; return *this; } + HttpRequest& Agent(const String& a) { agent = a; return *this; } + HttpRequest& ContentType(const String& a) { contenttype = a; return *this; } + + HttpRequest& Proxy(const String& host, int port) { proxy_host = host; proxy_port = port; return *this; } + HttpRequest& Proxy(const char *url); + HttpRequest& ProxyAuth(const String& u, const String& p) { proxy_username = u; proxy_password = p; return *this; } + + bool IsSocketError() const { return TcpSocket::IsError(); } + bool IsHttpError() const { return !IsNull(error) ; } + bool IsError() const { return IsSocketError() || IsHttpError(); } + String GetErrorDesc() const { return IsSocketError() ? TcpSocket::GetErrorDesc() : error; } + void ClearError() { TcpSocket::ClearError(); error.Clear(); } + + String GetHeader(const char *s) { return header[s]; } + String operator[](const char *s) { return GetHeader(s); } + String GetRedirectUrl(); + int GetContentLength(); + int GetStatusCode() const { return status_code; } + String GetReasonPhrase() const { return reason_phrase; } + + String GetContent() const { return body; } + String operator~() const { return GetContent(); } + operator String() const { return GetContent(); } + void ClearContent() { body.Clear(); } + + enum Phase { + START, DNS, REQUEST, HEADER, BODY, CHUNK_HEADER, CHUNK_BODY, TRAILER, FINISHED, FAILED + }; + + bool Do(); + int GetPhase() const { return phase; } + String GetPhaseName() const; + bool InProgress() const { return phase != FAILED && phase != FINISHED; } + bool IsFailure() const { return phase == FAILED; } + bool IsSuccess() const { return phase == FINISHED && status_code >= 200 && status_code < 300; } + + String Execute(); + + HttpRequest(); + HttpRequest(const char *url); + + static void Trace(bool b = true); +}; diff --git a/uppsrc/Core/src.tpp/Stream$en-us.tpp b/uppsrc/Core/src.tpp/Stream$en-us.tpp index b1943fb99..e7e55d322 100644 --- a/uppsrc/Core/src.tpp/Stream$en-us.tpp +++ b/uppsrc/Core/src.tpp/Stream$en-us.tpp @@ -410,6 +410,14 @@ Returns the number of bytes actually read.&] result as String.&] [s3; &] [s4;%- &] +[s5;:Stream`:`:GetAll`(void`*`,int`):%- [@(0.0.255) bool]_[* GetAll]([@(0.0.255) void]_`*[*@3 d +ata], [@(0.0.255) int]_[*@3 size])&] +[s2; Reads [%-*@3 size] bytes from the stream to memory at [%-*@3 data]. +If there is not enough data in the stream, LoadError is invoked +(that in turn might throw an exception). Returns true if required +number of bytes was read.&] +[s3; &] +[s4;%- &] [s5;:Stream`:`:LoadThrowing`(`):%- [@(0.0.255) void]_[* LoadThrowing]()&] [s2; Sets stream into the mode that throws LoadingError exception when LoadError is invoked. This mode is typical for serialization @@ -422,16 +430,6 @@ the LoadThrowing mode (by LoadThrowing() method), LoadingError exception is thrown.&] [s3; &] [s4;%- &] -[s5;:Stream`:`:GetAll`(void`*`,dword`):%- [@(0.0.255) bool]_[* GetAll]([@(0.0.255) void]_`* -[*@3 data], [_^dword^ dword]_[*@3 size])&] -[s2; Reads a block of raw binary data from the stream. If there is -not enough data in the stream, LoadError is invoked (that in -turn might throw an exception).&] -[s7; [%-*C@3 data]-|Pointer to buffer to receive the data.&] -[s7; [%-*C@3 size]-|Number of bytes to read.&] -[s7; [*/ Return value]-|true if required number of bytes was read.&] -[s3; &] -[s4;%- &] [s5;:Stream`:`:Get8`(`):%- [@(0.0.255) int]_[* Get8]()&] [s2; Reads a single byte from the stream. If there is not enough data in the stream, LoadError is invoked (that in turn might