524 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			524 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
| /*
 | |
|  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
 | |
|  *
 | |
|  *  Use of this source code is governed by a BSD-style license
 | |
|  *  that can be found in the LICENSE file in the root of the source
 | |
|  *  tree. An additional intellectual property rights grant can be found
 | |
|  *  in the file PATENTS.  All contributing project authors may
 | |
|  *  be found in the AUTHORS file in the root of the source tree.
 | |
|  */
 | |
| 
 | |
| #include "rtc_base/nat_socket_factory.h"
 | |
| 
 | |
| #include "rtc_base/arraysize.h"
 | |
| #include "rtc_base/checks.h"
 | |
| #include "rtc_base/logging.h"
 | |
| #include "rtc_base/nat_server.h"
 | |
| #include "rtc_base/virtual_socket_server.h"
 | |
| 
 | |
| namespace rtc {
 | |
| 
 | |
| // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
 | |
| // format that the natserver uses.
 | |
| // Returns 0 if an invalid address is passed.
 | |
| size_t PackAddressForNAT(char* buf,
 | |
|                          size_t buf_size,
 | |
|                          const SocketAddress& remote_addr) {
 | |
|   const IPAddress& ip = remote_addr.ipaddr();
 | |
|   int family = ip.family();
 | |
|   buf[0] = 0;
 | |
|   buf[1] = family;
 | |
|   // Writes the port.
 | |
|   *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
 | |
|   if (family == AF_INET) {
 | |
|     RTC_DCHECK(buf_size >= kNATEncodedIPv4AddressSize);
 | |
|     in_addr v4addr = ip.ipv4_address();
 | |
|     memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
 | |
|     return kNATEncodedIPv4AddressSize;
 | |
|   } else if (family == AF_INET6) {
 | |
|     RTC_DCHECK(buf_size >= kNATEncodedIPv6AddressSize);
 | |
|     in6_addr v6addr = ip.ipv6_address();
 | |
|     memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
 | |
|     return kNATEncodedIPv6AddressSize;
 | |
|   }
 | |
|   return 0U;
 | |
| }
 | |
| 
 | |
| // Decodes the remote address from a packet that has been encoded with the nat's
 | |
| // quasi-STUN format. Returns the length of the address (i.e., the offset into
 | |
| // data where the original packet starts).
 | |
| size_t UnpackAddressFromNAT(const char* buf,
 | |
|                             size_t buf_size,
 | |
|                             SocketAddress* remote_addr) {
 | |
|   RTC_DCHECK(buf_size >= 8);
 | |
|   RTC_DCHECK(buf[0] == 0);
 | |
|   int family = buf[1];
 | |
|   uint16_t port =
 | |
|       NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
 | |
|   if (family == AF_INET) {
 | |
|     const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
 | |
|     *remote_addr = SocketAddress(IPAddress(*v4addr), port);
 | |
|     return kNATEncodedIPv4AddressSize;
 | |
|   } else if (family == AF_INET6) {
 | |
|     RTC_DCHECK(buf_size >= 20);
 | |
|     const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
 | |
|     *remote_addr = SocketAddress(IPAddress(*v6addr), port);
 | |
|     return kNATEncodedIPv6AddressSize;
 | |
|   }
 | |
|   return 0U;
 | |
| }
 | |
| 
 | |
| // NATSocket
 | |
| class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
 | |
|  public:
 | |
|   explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
 | |
|       : sf_(sf),
 | |
|         family_(family),
 | |
|         type_(type),
 | |
|         connected_(false),
 | |
|         socket_(nullptr),
 | |
|         buf_(nullptr),
 | |
|         size_(0) {}
 | |
| 
 | |
|   ~NATSocket() override {
 | |
|     delete socket_;
 | |
|     delete[] buf_;
 | |
|   }
 | |
| 
 | |
|   SocketAddress GetLocalAddress() const override {
 | |
|     return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
 | |
|   }
 | |
| 
 | |
|   SocketAddress GetRemoteAddress() const override {
 | |
|     return remote_addr_;  // will be NIL if not connected
 | |
|   }
 | |
| 
 | |
|   int Bind(const SocketAddress& addr) override {
 | |
|     if (socket_) {  // already bound, bubble up error
 | |
|       return -1;
 | |
|     }
 | |
| 
 | |
|     return BindInternal(addr);
 | |
|   }
 | |
| 
 | |
|   int Connect(const SocketAddress& addr) override {
 | |
|     int result = 0;
 | |
|     // If we're not already bound (meaning |socket_| is null), bind to ANY
 | |
|     // address.
 | |
|     if (!socket_) {
 | |
|       result = BindInternal(SocketAddress(GetAnyIP(family_), 0));
 | |
|       if (result < 0) {
 | |
|         return result;
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     if (type_ == SOCK_STREAM) {
 | |
|       result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
 | |
|     } else {
 | |
|       connected_ = true;
 | |
|     }
 | |
| 
 | |
|     if (result >= 0) {
 | |
|       remote_addr_ = addr;
 | |
|     }
 | |
| 
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   int Send(const void* data, size_t size) override {
 | |
|     RTC_DCHECK(connected_);
 | |
|     return SendTo(data, size, remote_addr_);
 | |
|   }
 | |
| 
 | |
|   int SendTo(const void* data,
 | |
|              size_t size,
 | |
|              const SocketAddress& addr) override {
 | |
|     RTC_DCHECK(!connected_ || addr == remote_addr_);
 | |
|     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
 | |
|       return socket_->SendTo(data, size, addr);
 | |
|     }
 | |
|     // This array will be too large for IPv4 packets, but only by 12 bytes.
 | |
|     std::unique_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
 | |
|     size_t addrlength =
 | |
|         PackAddressForNAT(buf.get(), size + kNATEncodedIPv6AddressSize, addr);
 | |
|     size_t encoded_size = size + addrlength;
 | |
|     memcpy(buf.get() + addrlength, data, size);
 | |
|     int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
 | |
|     if (result >= 0) {
 | |
|       RTC_DCHECK(result == static_cast<int>(encoded_size));
 | |
|       result = result - static_cast<int>(addrlength);
 | |
|     }
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   int Recv(void* data, size_t size, int64_t* timestamp) override {
 | |
|     SocketAddress addr;
 | |
|     return RecvFrom(data, size, &addr, timestamp);
 | |
|   }
 | |
| 
 | |
|   int RecvFrom(void* data,
 | |
|                size_t size,
 | |
|                SocketAddress* out_addr,
 | |
|                int64_t* timestamp) override {
 | |
|     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
 | |
|       return socket_->RecvFrom(data, size, out_addr, timestamp);
 | |
|     }
 | |
|     // Make sure we have enough room to read the requested amount plus the
 | |
|     // largest possible header address.
 | |
|     SocketAddress remote_addr;
 | |
|     Grow(size + kNATEncodedIPv6AddressSize);
 | |
| 
 | |
|     // Read the packet from the socket.
 | |
|     int result = socket_->RecvFrom(buf_, size_, &remote_addr, timestamp);
 | |
|     if (result >= 0) {
 | |
|       RTC_DCHECK(remote_addr == server_addr_);
 | |
| 
 | |
|       // TODO: we need better framing so we know how many bytes we can
 | |
|       // return before we need to read the next address. For UDP, this will be
 | |
|       // fine as long as the reader always reads everything in the packet.
 | |
|       RTC_DCHECK((size_t)result < size_);
 | |
| 
 | |
|       // Decode the wire packet into the actual results.
 | |
|       SocketAddress real_remote_addr;
 | |
|       size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr);
 | |
|       memcpy(data, buf_ + addrlength, result - addrlength);
 | |
| 
 | |
|       // Make sure this packet should be delivered before returning it.
 | |
|       if (!connected_ || (real_remote_addr == remote_addr_)) {
 | |
|         if (out_addr)
 | |
|           *out_addr = real_remote_addr;
 | |
|         result = result - static_cast<int>(addrlength);
 | |
|       } else {
 | |
|         RTC_LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
 | |
|                           << real_remote_addr.ToString();
 | |
|         result = 0;  // Tell the caller we didn't read anything
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   int Close() override {
 | |
|     int result = 0;
 | |
|     if (socket_) {
 | |
|       result = socket_->Close();
 | |
|       if (result >= 0) {
 | |
|         connected_ = false;
 | |
|         remote_addr_ = SocketAddress();
 | |
|         delete socket_;
 | |
|         socket_ = nullptr;
 | |
|       }
 | |
|     }
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   int Listen(int backlog) override { return socket_->Listen(backlog); }
 | |
|   AsyncSocket* Accept(SocketAddress* paddr) override {
 | |
|     return socket_->Accept(paddr);
 | |
|   }
 | |
|   int GetError() const override {
 | |
|     return socket_ ? socket_->GetError() : error_;
 | |
|   }
 | |
|   void SetError(int error) override {
 | |
|     if (socket_) {
 | |
|       socket_->SetError(error);
 | |
|     } else {
 | |
|       error_ = error;
 | |
|     }
 | |
|   }
 | |
|   ConnState GetState() const override {
 | |
|     return connected_ ? CS_CONNECTED : CS_CLOSED;
 | |
|   }
 | |
|   int GetOption(Option opt, int* value) override {
 | |
|     return socket_->GetOption(opt, value);
 | |
|   }
 | |
|   int SetOption(Option opt, int value) override {
 | |
|     return socket_->SetOption(opt, value);
 | |
|   }
 | |
| 
 | |
|   void OnConnectEvent(AsyncSocket* socket) {
 | |
|     // If we're NATed, we need to send a message with the real addr to use.
 | |
|     RTC_DCHECK(socket == socket_);
 | |
|     if (server_addr_.IsNil()) {
 | |
|       connected_ = true;
 | |
|       SignalConnectEvent(this);
 | |
|     } else {
 | |
|       SendConnectRequest();
 | |
|     }
 | |
|   }
 | |
|   void OnReadEvent(AsyncSocket* socket) {
 | |
|     // If we're NATed, we need to process the connect reply.
 | |
|     RTC_DCHECK(socket == socket_);
 | |
|     if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
 | |
|       HandleConnectReply();
 | |
|     } else {
 | |
|       SignalReadEvent(this);
 | |
|     }
 | |
|   }
 | |
|   void OnWriteEvent(AsyncSocket* socket) {
 | |
|     RTC_DCHECK(socket == socket_);
 | |
|     SignalWriteEvent(this);
 | |
|   }
 | |
|   void OnCloseEvent(AsyncSocket* socket, int error) {
 | |
|     RTC_DCHECK(socket == socket_);
 | |
|     SignalCloseEvent(this, error);
 | |
|   }
 | |
| 
 | |
|  private:
 | |
|   int BindInternal(const SocketAddress& addr) {
 | |
|     RTC_DCHECK(!socket_);
 | |
| 
 | |
|     int result;
 | |
|     socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
 | |
|     result = (socket_) ? socket_->Bind(addr) : -1;
 | |
|     if (result >= 0) {
 | |
|       socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
 | |
|       socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
 | |
|       socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
 | |
|       socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
 | |
|     } else {
 | |
|       server_addr_.Clear();
 | |
|       delete socket_;
 | |
|       socket_ = nullptr;
 | |
|     }
 | |
| 
 | |
|     return result;
 | |
|   }
 | |
| 
 | |
|   // Makes sure the buffer is at least the given size.
 | |
|   void Grow(size_t new_size) {
 | |
|     if (size_ < new_size) {
 | |
|       delete[] buf_;
 | |
|       size_ = new_size;
 | |
|       buf_ = new char[size_];
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   // Sends the destination address to the server to tell it to connect.
 | |
|   void SendConnectRequest() {
 | |
|     char buf[kNATEncodedIPv6AddressSize];
 | |
|     size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
 | |
|     socket_->Send(buf, length);
 | |
|   }
 | |
| 
 | |
|   // Handles the byte sent back from the server and fires the appropriate event.
 | |
|   void HandleConnectReply() {
 | |
|     char code;
 | |
|     socket_->Recv(&code, sizeof(code), nullptr);
 | |
|     if (code == 0) {
 | |
|       connected_ = true;
 | |
|       SignalConnectEvent(this);
 | |
|     } else {
 | |
|       Close();
 | |
|       SignalCloseEvent(this, code);
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   NATInternalSocketFactory* sf_;
 | |
|   int family_;
 | |
|   int type_;
 | |
|   bool connected_;
 | |
|   SocketAddress remote_addr_;
 | |
|   SocketAddress server_addr_;  // address of the NAT server
 | |
|   AsyncSocket* socket_;
 | |
|   // Need to hold error in case it occurs before the socket is created.
 | |
|   int error_ = 0;
 | |
|   char* buf_;
 | |
|   size_t size_;
 | |
| };
 | |
| 
 | |
| // NATSocketFactory
 | |
| NATSocketFactory::NATSocketFactory(SocketFactory* factory,
 | |
|                                    const SocketAddress& nat_udp_addr,
 | |
|                                    const SocketAddress& nat_tcp_addr)
 | |
|     : factory_(factory),
 | |
|       nat_udp_addr_(nat_udp_addr),
 | |
|       nat_tcp_addr_(nat_tcp_addr) {}
 | |
| 
 | |
| Socket* NATSocketFactory::CreateSocket(int family, int type) {
 | |
|   return new NATSocket(this, family, type);
 | |
| }
 | |
| 
 | |
| AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
 | |
|   return new NATSocket(this, family, type);
 | |
| }
 | |
| 
 | |
| AsyncSocket* NATSocketFactory::CreateInternalSocket(
 | |
|     int family,
 | |
|     int type,
 | |
|     const SocketAddress& local_addr,
 | |
|     SocketAddress* nat_addr) {
 | |
|   if (type == SOCK_STREAM) {
 | |
|     *nat_addr = nat_tcp_addr_;
 | |
|   } else {
 | |
|     *nat_addr = nat_udp_addr_;
 | |
|   }
 | |
|   return factory_->CreateAsyncSocket(family, type);
 | |
| }
 | |
| 
 | |
| // NATSocketServer
 | |
| NATSocketServer::NATSocketServer(SocketServer* server)
 | |
|     : server_(server), msg_queue_(nullptr) {}
 | |
| 
 | |
| NATSocketServer::Translator* NATSocketServer::GetTranslator(
 | |
|     const SocketAddress& ext_ip) {
 | |
|   return nats_.Get(ext_ip);
 | |
| }
 | |
| 
 | |
| NATSocketServer::Translator* NATSocketServer::AddTranslator(
 | |
|     const SocketAddress& ext_ip,
 | |
|     const SocketAddress& int_ip,
 | |
|     NATType type) {
 | |
|   // Fail if a translator already exists with this extternal address.
 | |
|   if (nats_.Get(ext_ip))
 | |
|     return nullptr;
 | |
| 
 | |
|   return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
 | |
| }
 | |
| 
 | |
| void NATSocketServer::RemoveTranslator(const SocketAddress& ext_ip) {
 | |
|   nats_.Remove(ext_ip);
 | |
| }
 | |
| 
 | |
| Socket* NATSocketServer::CreateSocket(int family, int type) {
 | |
|   return new NATSocket(this, family, type);
 | |
| }
 | |
| 
 | |
| AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
 | |
|   return new NATSocket(this, family, type);
 | |
| }
 | |
| 
 | |
| void NATSocketServer::SetMessageQueue(Thread* queue) {
 | |
|   msg_queue_ = queue;
 | |
|   server_->SetMessageQueue(queue);
 | |
| }
 | |
| 
 | |
| bool NATSocketServer::Wait(int cms, bool process_io) {
 | |
|   return server_->Wait(cms, process_io);
 | |
| }
 | |
| 
 | |
| void NATSocketServer::WakeUp() {
 | |
|   server_->WakeUp();
 | |
| }
 | |
| 
 | |
| AsyncSocket* NATSocketServer::CreateInternalSocket(
 | |
|     int family,
 | |
|     int type,
 | |
|     const SocketAddress& local_addr,
 | |
|     SocketAddress* nat_addr) {
 | |
|   AsyncSocket* socket = nullptr;
 | |
|   Translator* nat = nats_.FindClient(local_addr);
 | |
|   if (nat) {
 | |
|     socket = nat->internal_factory()->CreateAsyncSocket(family, type);
 | |
|     *nat_addr = (type == SOCK_STREAM) ? nat->internal_tcp_address()
 | |
|                                       : nat->internal_udp_address();
 | |
|   } else {
 | |
|     socket = server_->CreateAsyncSocket(family, type);
 | |
|   }
 | |
|   return socket;
 | |
| }
 | |
| 
 | |
| // NATSocketServer::Translator
 | |
| NATSocketServer::Translator::Translator(NATSocketServer* server,
 | |
|                                         NATType type,
 | |
|                                         const SocketAddress& int_ip,
 | |
|                                         SocketFactory* ext_factory,
 | |
|                                         const SocketAddress& ext_ip)
 | |
|     : server_(server) {
 | |
|   // Create a new private network, and a NATServer running on the private
 | |
|   // network that bridges to the external network. Also tell the private
 | |
|   // network to use the same message queue as us.
 | |
|   VirtualSocketServer* internal_server = new VirtualSocketServer();
 | |
|   internal_server->SetMessageQueue(server_->queue());
 | |
|   internal_factory_.reset(internal_server);
 | |
|   nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip,
 | |
|                                   ext_factory, ext_ip));
 | |
| }
 | |
| 
 | |
| NATSocketServer::Translator::~Translator() = default;
 | |
| 
 | |
| NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
 | |
|     const SocketAddress& ext_ip) {
 | |
|   return nats_.Get(ext_ip);
 | |
| }
 | |
| 
 | |
| NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
 | |
|     const SocketAddress& ext_ip,
 | |
|     const SocketAddress& int_ip,
 | |
|     NATType type) {
 | |
|   // Fail if a translator already exists with this extternal address.
 | |
|   if (nats_.Get(ext_ip))
 | |
|     return nullptr;
 | |
| 
 | |
|   AddClient(ext_ip);
 | |
|   return nats_.Add(ext_ip,
 | |
|                    new Translator(server_, type, int_ip, server_, ext_ip));
 | |
| }
 | |
| void NATSocketServer::Translator::RemoveTranslator(
 | |
|     const SocketAddress& ext_ip) {
 | |
|   nats_.Remove(ext_ip);
 | |
|   RemoveClient(ext_ip);
 | |
| }
 | |
| 
 | |
| bool NATSocketServer::Translator::AddClient(const SocketAddress& int_ip) {
 | |
|   // Fail if a client already exists with this internal address.
 | |
|   if (clients_.find(int_ip) != clients_.end())
 | |
|     return false;
 | |
| 
 | |
|   clients_.insert(int_ip);
 | |
|   return true;
 | |
| }
 | |
| 
 | |
| void NATSocketServer::Translator::RemoveClient(const SocketAddress& int_ip) {
 | |
|   std::set<SocketAddress>::iterator it = clients_.find(int_ip);
 | |
|   if (it != clients_.end()) {
 | |
|     clients_.erase(it);
 | |
|   }
 | |
| }
 | |
| 
 | |
| NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
 | |
|     const SocketAddress& int_ip) {
 | |
|   // See if we have the requested IP, or any of our children do.
 | |
|   return (clients_.find(int_ip) != clients_.end()) ? this
 | |
|                                                    : nats_.FindClient(int_ip);
 | |
| }
 | |
| 
 | |
| // NATSocketServer::TranslatorMap
 | |
| NATSocketServer::TranslatorMap::~TranslatorMap() {
 | |
|   for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
 | |
|     delete it->second;
 | |
|   }
 | |
| }
 | |
| 
 | |
| NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
 | |
|     const SocketAddress& ext_ip) {
 | |
|   TranslatorMap::iterator it = find(ext_ip);
 | |
|   return (it != end()) ? it->second : nullptr;
 | |
| }
 | |
| 
 | |
| NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
 | |
|     const SocketAddress& ext_ip,
 | |
|     Translator* nat) {
 | |
|   (*this)[ext_ip] = nat;
 | |
|   return nat;
 | |
| }
 | |
| 
 | |
| void NATSocketServer::TranslatorMap::Remove(const SocketAddress& ext_ip) {
 | |
|   TranslatorMap::iterator it = find(ext_ip);
 | |
|   if (it != end()) {
 | |
|     delete it->second;
 | |
|     erase(it);
 | |
|   }
 | |
| }
 | |
| 
 | |
| NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
 | |
|     const SocketAddress& int_ip) {
 | |
|   Translator* nat = nullptr;
 | |
|   for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
 | |
|     nat = it->second->FindClient(int_ip);
 | |
|   }
 | |
|   return nat;
 | |
| }
 | |
| 
 | |
| }  // namespace rtc
 |