| /* |
| * Copyright (c) 2017, The OpenThread Authors. |
| * All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions are met: |
| * 1. Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * 2. Redistributions in binary form must reproduce the above copyright |
| * notice, this list of conditions and the following disclaimer in the |
| * documentation and/or other materials provided with the distribution. |
| * 3. Neither the name of the copyright holder nor the |
| * names of its contributors may be used to endorse or promote products |
| * derived from this software without specific prior written permission. |
| * |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
| * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE |
| * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
| * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
| * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
| * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
| * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
| * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE |
| * POSSIBILITY OF SUCH DAMAGE. |
| */ |
| |
| #include <openthread/config.h> |
| |
| #include "dns_client.hpp" |
| |
| #include "utils/wrap_string.h" |
| |
| #include "common/code_utils.hpp" |
| #include "common/debug.hpp" |
| #include "net/udp6.hpp" |
| #include "thread/thread_netif.hpp" |
| |
| #if OPENTHREAD_ENABLE_DNS_CLIENT |
| |
| /** |
| * @file |
| * This file implements the DNS client. |
| */ |
| |
| using ot::Encoding::BigEndian::HostSwap16; |
| |
| namespace ot { |
| namespace Dns { |
| |
| otError Client::Start(void) |
| { |
| otError error; |
| Ip6::SockAddr addr; |
| |
| SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this)); |
| SuccessOrExit(error = mSocket.Bind(addr)); |
| |
| exit: |
| return error; |
| } |
| |
| otError Client::Stop(void) |
| { |
| Message *message = mPendingQueries.GetHead(); |
| Message *messageToRemove; |
| QueryMetadata queryMetadata; |
| |
| // Remove all pending queries. |
| while (message != NULL) |
| { |
| messageToRemove = message; |
| message = message->GetNext(); |
| |
| queryMetadata.ReadFrom(*messageToRemove); |
| FinalizeDnsTransaction(*messageToRemove, queryMetadata, NULL, 0, OT_ERROR_ABORT); |
| } |
| |
| return mSocket.Close(); |
| } |
| |
| otError Client::Query(const otDnsQuery *aQuery, otDnsResponseHandler aHandler, void *aContext) |
| { |
| otError error; |
| QueryMetadata queryMetadata(aHandler, aContext); |
| Message *message = NULL; |
| Message *messageCopy = NULL; |
| Header header; |
| QuestionAaaa question; |
| const Ip6::MessageInfo *messageInfo; |
| |
| VerifyOrExit(aQuery->mHostname != NULL && aQuery->mMessageInfo != NULL, |
| error = OT_ERROR_INVALID_ARGS); |
| |
| header.SetMessageId(mMessageId++); |
| header.SetType(Header::kTypeQuery); |
| header.SetQueryType(Header::kQueryTypeStandard); |
| |
| if (!aQuery->mNoRecursion) |
| { |
| header.SetRecursionDesiredFlag(); |
| } |
| |
| header.SetQuestionCount(1); |
| |
| VerifyOrExit((message = NewMessage(header)) != NULL, error = OT_ERROR_NO_BUFS); |
| |
| SuccessOrExit(error = AppendCompressedHostname(*message, aQuery->mHostname)); |
| SuccessOrExit(error = question.AppendTo(*message)); |
| |
| messageInfo = static_cast<const Ip6::MessageInfo *>(aQuery->mMessageInfo); |
| |
| queryMetadata.mHostname = aQuery->mHostname; |
| queryMetadata.mTransmissionTime = Timer::GetNow() + kResponseTimeout; |
| queryMetadata.mSourceAddress = messageInfo->GetSockAddr(); |
| queryMetadata.mDestinationPort = messageInfo->GetPeerPort(); |
| queryMetadata.mDestinationAddress = messageInfo->GetPeerAddr(); |
| queryMetadata.mRetransmissionCount = 0; |
| |
| VerifyOrExit((messageCopy = CopyAndEnqueueMessage(*message, queryMetadata)) != NULL, |
| error = OT_ERROR_NO_BUFS); |
| SuccessOrExit(error = SendMessage(*message, *messageInfo)); |
| |
| exit: |
| |
| if (error != OT_ERROR_NONE) |
| { |
| if (message) |
| { |
| message->Free(); |
| } |
| |
| if (messageCopy) |
| { |
| DequeueMessage(*messageCopy); |
| } |
| } |
| |
| return error; |
| } |
| |
| Message *Client::NewMessage(const Header &aHeader) |
| { |
| Message *message = NULL; |
| |
| VerifyOrExit((message = mSocket.NewMessage(sizeof(aHeader))) != NULL); |
| message->Prepend(&aHeader, sizeof(aHeader)); |
| message->SetOffset(0); |
| |
| exit: |
| return message; |
| } |
| |
| Message *Client::CopyAndEnqueueMessage(const Message &aMessage, const QueryMetadata &aQueryMetadata) |
| { |
| otError error = OT_ERROR_NONE; |
| uint32_t now = Timer::GetNow(); |
| Message *messageCopy = NULL; |
| uint32_t nextTransmissionTime; |
| |
| // Create a message copy for further retransmissions. |
| VerifyOrExit((messageCopy = aMessage.Clone()) != NULL, error = OT_ERROR_NO_BUFS); |
| |
| // Append the copy with retransmission data and add it to the queue. |
| SuccessOrExit(error = aQueryMetadata.AppendTo(*messageCopy)); |
| mPendingQueries.Enqueue(*messageCopy); |
| |
| // Setup the timer. |
| if (mRetransmissionTimer.IsRunning()) |
| { |
| // If timer is already running, check if it should be restarted with earlier fire time. |
| nextTransmissionTime = mRetransmissionTimer.GetFireTime(); |
| |
| if (aQueryMetadata.IsEarlier(nextTransmissionTime)) |
| { |
| mRetransmissionTimer.Start(aQueryMetadata.mTransmissionTime - now); |
| } |
| } |
| else |
| { |
| mRetransmissionTimer.Start(aQueryMetadata.mTransmissionTime - now); |
| } |
| |
| exit: |
| |
| if (error != OT_ERROR_NONE && messageCopy != NULL) |
| { |
| messageCopy->Free(); |
| messageCopy = NULL; |
| } |
| |
| return messageCopy; |
| } |
| |
| void Client::DequeueMessage(Message &aMessage) |
| { |
| mPendingQueries.Dequeue(aMessage); |
| |
| if (mRetransmissionTimer.IsRunning() && (mPendingQueries.GetHead() == NULL)) |
| { |
| // No more requests pending, stop the timer. |
| mRetransmissionTimer.Stop(); |
| } |
| |
| // Free the message memory. |
| aMessage.Free(); |
| } |
| |
| otError Client::SendMessage(Message &aMessage, const Ip6::MessageInfo &aMessageInfo) |
| { |
| return mSocket.SendTo(aMessage, aMessageInfo); |
| } |
| |
| otError Client::SendCopy(const Message &aMessage, const Ip6::MessageInfo &aMessageInfo) |
| { |
| otError error; |
| Message *messageCopy = NULL; |
| |
| // Create a message copy for lower layers. |
| VerifyOrExit((messageCopy = aMessage.Clone(aMessage.GetLength() - sizeof(QueryMetadata))) != NULL, |
| error = OT_ERROR_NO_BUFS); |
| |
| // Send the copy. |
| SuccessOrExit(error = SendMessage(*messageCopy, aMessageInfo)); |
| |
| exit: |
| |
| if (error != OT_ERROR_NONE && messageCopy != NULL) |
| { |
| messageCopy->Free(); |
| } |
| |
| return error; |
| } |
| |
| otError Client::AppendCompressedHostname(Message &aMessage, const char *aHostname) |
| { |
| otError error = OT_ERROR_NONE; |
| uint8_t index = 0; |
| uint8_t labelPosition = 0; |
| uint8_t labelSize = 0; |
| |
| while (true) |
| { |
| // Look for string separator. |
| if (aHostname[index] == kLabelSeparator || aHostname[index] == kLabelTerminator) |
| { |
| VerifyOrExit(labelSize > 0, error = OT_ERROR_INVALID_ARGS); |
| SuccessOrExit(error = aMessage.Append(&labelSize, 1)); |
| SuccessOrExit(error = aMessage.Append(&aHostname[labelPosition], labelSize)); |
| |
| labelPosition += labelSize + 1; |
| labelSize = 0; |
| |
| if (aHostname[index] == kLabelTerminator) |
| { |
| break; |
| } |
| } |
| else |
| { |
| labelSize++; |
| } |
| |
| index++; |
| } |
| |
| // Add termination character at the end. |
| labelSize = kLabelTerminator; |
| SuccessOrExit(error = aMessage.Append(&labelSize, 1)); |
| |
| exit: |
| return error; |
| } |
| |
| otError Client::CompareQuestions(Message &aMessageResponse, Message &aMessageQuery, uint16_t &aOffset) |
| { |
| otError error = OT_ERROR_NONE; |
| uint8_t bufQuery[kBufSize]; |
| uint8_t bufResponse[kBufSize]; |
| uint16_t read = 0; |
| |
| // Compare question section of the query with the response. |
| uint16_t length = aMessageQuery.GetLength() - aMessageQuery.GetOffset() - |
| sizeof(Header) - sizeof(QueryMetadata); |
| uint16_t offset = aMessageQuery.GetOffset() + sizeof(Header); |
| |
| while (length > 0) |
| { |
| VerifyOrExit((read = aMessageQuery.Read(offset, |
| length < sizeof(bufQuery) ? length : sizeof(bufQuery), |
| bufQuery)) > 0, error = OT_ERROR_PARSE); |
| VerifyOrExit(aMessageResponse.Read(aOffset, read, bufResponse) == read, |
| error = OT_ERROR_PARSE); |
| |
| VerifyOrExit(memcmp(bufResponse, bufQuery, read) == 0, error = OT_ERROR_NOT_FOUND); |
| |
| aOffset += read; |
| offset += read; |
| length -= read; |
| } |
| |
| exit: |
| return error; |
| } |
| |
| otError Client::SkipHostname(Message &aMessage, uint16_t &aOffset) |
| { |
| otError error = OT_ERROR_NONE; |
| uint8_t buf[kBufSize]; |
| uint16_t index; |
| uint16_t read = 0; |
| uint16_t offset = aOffset; |
| uint16_t length = aMessage.GetLength() - aOffset; |
| |
| while (length > 0) |
| { |
| VerifyOrExit((read = aMessage.Read(offset, sizeof(buf), buf)) > 0, |
| error = OT_ERROR_PARSE); |
| |
| index = 0; |
| |
| while (index < read) |
| { |
| if (buf[index] == kLabelTerminator) |
| { |
| ExitNow(aOffset = offset + 1); |
| } |
| |
| if ((buf[index] & kCompressionOffsetMask) == kCompressionOffsetMask) |
| { |
| ExitNow(aOffset = offset + 2); |
| } |
| |
| index++; |
| offset++; |
| } |
| |
| length -= read; |
| } |
| |
| ExitNow(error = OT_ERROR_PARSE); |
| |
| exit: |
| return error; |
| } |
| |
| Message *Client::FindRelatedQuery(const Header &aResponseHeader, QueryMetadata &aQueryMetadata) |
| { |
| uint16_t messageId; |
| Message *message = mPendingQueries.GetHead(); |
| |
| while (message != NULL) |
| { |
| // Partially read DNS header to obtain message ID only. |
| assert(message->Read(message->GetOffset(), sizeof(messageId), &messageId) == sizeof(messageId)); |
| |
| if (HostSwap16(messageId) == aResponseHeader.GetMessageId()) |
| { |
| aQueryMetadata.ReadFrom(*message); |
| ExitNow(); |
| } |
| |
| message = message->GetNext(); |
| } |
| |
| exit: |
| return message; |
| } |
| |
| void Client::FinalizeDnsTransaction(Message &aQuery, const QueryMetadata &aQueryMetadata, |
| otIp6Address *aAddress, uint32_t aTtl, |
| otError aResult) |
| { |
| DequeueMessage(aQuery); |
| |
| if (aQueryMetadata.mResponseHandler != NULL) |
| { |
| aQueryMetadata.mResponseHandler(aQueryMetadata.mResponseContext, aQueryMetadata.mHostname, |
| aAddress, aTtl, aResult); |
| } |
| } |
| |
| void Client::HandleRetransmissionTimer(Timer &aTimer) |
| { |
| GetOwner(aTimer).HandleRetransmissionTimer(); |
| } |
| |
| void Client::HandleRetransmissionTimer(void) |
| { |
| uint32_t now = Timer::GetNow(); |
| uint32_t nextDelta = 0xffffffff; |
| QueryMetadata queryMetadata; |
| Message *message = mPendingQueries.GetHead(); |
| Message *nextMessage = NULL; |
| Ip6::MessageInfo messageInfo; |
| |
| while (message != NULL) |
| { |
| nextMessage = message->GetNext(); |
| queryMetadata.ReadFrom(*message); |
| |
| if (queryMetadata.IsLater(now)) |
| { |
| // Calculate the next delay and choose the lowest. |
| if (queryMetadata.mTransmissionTime - now < nextDelta) |
| { |
| nextDelta = queryMetadata.mTransmissionTime - now; |
| } |
| } |
| else if (queryMetadata.mRetransmissionCount < kMaxRetransmit) |
| { |
| // Increment retransmission counter and timer. |
| queryMetadata.mRetransmissionCount++; |
| queryMetadata.mTransmissionTime = now + kResponseTimeout; |
| queryMetadata.UpdateIn(*message); |
| |
| // Check if retransmission time is lower than current lowest. |
| if (queryMetadata.mTransmissionTime - now < nextDelta) |
| { |
| nextDelta = queryMetadata.mTransmissionTime - now; |
| } |
| |
| // Retransmit |
| messageInfo.SetPeerAddr(queryMetadata.mDestinationAddress); |
| messageInfo.SetPeerPort(queryMetadata.mDestinationPort); |
| messageInfo.SetSockAddr(queryMetadata.mSourceAddress); |
| |
| SendCopy(*message, messageInfo); |
| } |
| else |
| { |
| // No expected response. |
| FinalizeDnsTransaction(*message, queryMetadata, NULL, 0, OT_ERROR_RESPONSE_TIMEOUT); |
| } |
| |
| message = nextMessage; |
| } |
| |
| if (nextDelta != 0xffffffff) |
| { |
| mRetransmissionTimer.Start(nextDelta); |
| } |
| } |
| |
| void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo) |
| { |
| static_cast<Client *>(aContext)->HandleUdpReceive(*static_cast<Message *>(aMessage), |
| *static_cast<const Ip6::MessageInfo *>(aMessageInfo)); |
| } |
| |
| void Client::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo) |
| { |
| otError error = OT_ERROR_NONE; |
| Header responseHeader; |
| QueryMetadata queryMetadata; |
| ResourceRecordAaaa record; |
| Message *message = NULL; |
| uint16_t offset; |
| |
| // RFC1035 7.3. Resolver cannot rely that a response will come from the same address |
| // which it sent the corresponding query to. |
| OT_UNUSED_VARIABLE(aMessageInfo); |
| |
| VerifyOrExit(aMessage.Read(aMessage.GetOffset(), sizeof(responseHeader), &responseHeader) == |
| sizeof(responseHeader)); |
| VerifyOrExit(responseHeader.GetType() == Header::kTypeResponse && |
| responseHeader.GetQuestionCount() == 1 && |
| responseHeader.IsTruncationFlagSet() == false); |
| |
| aMessage.MoveOffset(sizeof(responseHeader)); |
| offset = aMessage.GetOffset(); |
| |
| VerifyOrExit((message = FindRelatedQuery(responseHeader, queryMetadata)) != NULL); |
| |
| if (responseHeader.GetResponseCode() != Header::kResponseSuccess) |
| { |
| ExitNow(error = OT_ERROR_FAILED); |
| } |
| |
| // Parse and check the question section. |
| SuccessOrExit(error = CompareQuestions(aMessage, *message, offset)); |
| |
| // Parse and check the answer section. |
| for (uint32_t index = 0; index < responseHeader.GetAnswerCount(); index++) |
| { |
| SuccessOrExit(error = SkipHostname(aMessage, offset)); |
| |
| if (offset + sizeof(ResourceRecord) > aMessage.GetLength()) |
| { |
| ExitNow(error = OT_ERROR_PARSE); |
| } |
| |
| if (aMessage.Read(offset, sizeof(record), &record) != sizeof(record) || |
| record.GetType() != ResourceRecordAaaa::kType || |
| record.GetClass() != ResourceRecordAaaa::kClass) |
| { |
| offset += sizeof(ResourceRecord) + record.GetLength(); |
| |
| continue; |
| } |
| |
| // Return the first found IPv6 address. |
| FinalizeDnsTransaction(*message, queryMetadata, &record.GetAddress(), record.GetTtl(), OT_ERROR_NONE); |
| |
| ExitNow(); |
| } |
| |
| ExitNow(error = OT_ERROR_NOT_FOUND); |
| |
| exit: |
| |
| if (message != NULL && error != OT_ERROR_NONE) |
| { |
| FinalizeDnsTransaction(*message, queryMetadata, NULL, 0, error); |
| } |
| |
| return; |
| } |
| |
| Client &Client::GetOwner(const Context &aContext) |
| { |
| #if OPENTHREAD_ENABLE_MULTIPLE_INSTANCES |
| Client &client = *static_cast<Client *>(aContext.GetContext()); |
| #else |
| Client &client = otGetThreadNetif().GetDnsClient(); |
| OT_UNUSED_VARIABLE(aContext); |
| #endif |
| return client; |
| } |
| |
| } // namespace Coap |
| } // namespace ot |
| |
| #endif // OPENTHREAD_ENABLE_DNS_CLIENT |