224 lines
		
	
	
	
		
			5.5 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			224 lines
		
	
	
	
		
			5.5 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
/*
 | 
						|
This file is part of Telegram Desktop,
 | 
						|
the official desktop application for the Telegram messaging service.
 | 
						|
 | 
						|
For license and copyright information please follow this link:
 | 
						|
https://github.com/telegramdesktop/tdesktop/blob/master/LEGAL
 | 
						|
*/
 | 
						|
#include "mtproto/concurrent_sender.h"
 | 
						|
 | 
						|
#include "mtproto/mtp_instance.h"
 | 
						|
#include "mtproto/rpc_sender.h"
 | 
						|
 | 
						|
namespace MTP {
 | 
						|
 | 
						|
class ConcurrentSender::RPCDoneHandler : public RPCAbstractDoneHandler {
 | 
						|
public:
 | 
						|
	RPCDoneHandler(
 | 
						|
		not_null<ConcurrentSender*> sender,
 | 
						|
		Fn<void(FnMut<void()>)> runner);
 | 
						|
 | 
						|
	void operator()(
 | 
						|
		mtpRequestId requestId,
 | 
						|
		const mtpPrime *from,
 | 
						|
		const mtpPrime *end) override;
 | 
						|
 | 
						|
private:
 | 
						|
	base::weak_ptr<ConcurrentSender> _weak;
 | 
						|
	Fn<void(FnMut<void()>)> _runner;
 | 
						|
 | 
						|
};
 | 
						|
 | 
						|
class ConcurrentSender::RPCFailHandler : public RPCAbstractFailHandler {
 | 
						|
public:
 | 
						|
	RPCFailHandler(
 | 
						|
		not_null<ConcurrentSender*> sender,
 | 
						|
		Fn<void(FnMut<void()>)> runner,
 | 
						|
		FailSkipPolicy skipPolicy);
 | 
						|
 | 
						|
	bool operator()(
 | 
						|
		mtpRequestId requestId,
 | 
						|
		const RPCError &error) override;
 | 
						|
 | 
						|
private:
 | 
						|
	base::weak_ptr<ConcurrentSender> _weak;
 | 
						|
	Fn<void(FnMut<void()>)> _runner;
 | 
						|
	FailSkipPolicy _skipPolicy = FailSkipPolicy::Simple;
 | 
						|
 | 
						|
};
 | 
						|
 | 
						|
ConcurrentSender::RPCDoneHandler::RPCDoneHandler(
 | 
						|
	not_null<ConcurrentSender*> sender,
 | 
						|
	Fn<void(FnMut<void()>)> runner)
 | 
						|
: _weak(sender)
 | 
						|
, _runner(std::move(runner)) {
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::RPCDoneHandler::operator()(
 | 
						|
		mtpRequestId requestId,
 | 
						|
		const mtpPrime *from,
 | 
						|
		const mtpPrime *end) {
 | 
						|
	auto response = gsl::make_span(
 | 
						|
		from,
 | 
						|
		end - from);
 | 
						|
	_runner([=, weak = _weak, moved = bytes::make_vector(response)]() mutable {
 | 
						|
		if (const auto strong = weak.get()) {
 | 
						|
			strong->senderRequestDone(requestId, std::move(moved));
 | 
						|
		}
 | 
						|
	});
 | 
						|
}
 | 
						|
 | 
						|
ConcurrentSender::RPCFailHandler::RPCFailHandler(
 | 
						|
	not_null<ConcurrentSender*> sender,
 | 
						|
	Fn<void(FnMut<void()>)> runner,
 | 
						|
	FailSkipPolicy skipPolicy)
 | 
						|
: _weak(sender)
 | 
						|
, _runner(std::move(runner))
 | 
						|
, _skipPolicy(skipPolicy) {
 | 
						|
}
 | 
						|
 | 
						|
bool ConcurrentSender::RPCFailHandler::operator()(
 | 
						|
		mtpRequestId requestId,
 | 
						|
		const RPCError &error) {
 | 
						|
	if (_skipPolicy == FailSkipPolicy::Simple) {
 | 
						|
		if (MTP::isDefaultHandledError(error)) {
 | 
						|
			return false;
 | 
						|
		}
 | 
						|
	} else if (_skipPolicy == FailSkipPolicy::HandleFlood) {
 | 
						|
		if (MTP::isDefaultHandledError(error) && !MTP::isFloodError(error)) {
 | 
						|
			return false;
 | 
						|
		}
 | 
						|
	}
 | 
						|
	_runner([=, weak = _weak, error = error]() mutable {
 | 
						|
		if (const auto strong = weak.get()) {
 | 
						|
			strong->senderRequestFail(requestId, std::move(error));
 | 
						|
		}
 | 
						|
	});
 | 
						|
	return true;
 | 
						|
}
 | 
						|
 | 
						|
template <typename Method>
 | 
						|
auto ConcurrentSender::with_instance(Method &&method)
 | 
						|
-> std::enable_if_t<is_callable_v<Method, not_null<Instance*>>> {
 | 
						|
	crl::on_main([method = std::forward<Method>(method)]() mutable {
 | 
						|
		if (const auto instance = MainInstance()) {
 | 
						|
			std::move(method)(instance);
 | 
						|
		}
 | 
						|
	});
 | 
						|
}
 | 
						|
 | 
						|
ConcurrentSender::RequestBuilder::RequestBuilder(
 | 
						|
	not_null<ConcurrentSender*> sender,
 | 
						|
	SecureRequest &&serialized) noexcept
 | 
						|
: _sender(sender)
 | 
						|
, _serialized(std::move(serialized)) {
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::RequestBuilder::setToDC(ShiftedDcId dcId) noexcept {
 | 
						|
	_dcId = dcId;
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::RequestBuilder::setCanWait(crl::time ms) noexcept {
 | 
						|
	_canWait = ms;
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::RequestBuilder::setFailSkipPolicy(
 | 
						|
		FailSkipPolicy policy) noexcept {
 | 
						|
	_failSkipPolicy = policy;
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::RequestBuilder::setAfter(
 | 
						|
		mtpRequestId requestId) noexcept {
 | 
						|
	_afterRequestId = requestId;
 | 
						|
}
 | 
						|
 | 
						|
mtpRequestId ConcurrentSender::RequestBuilder::send() {
 | 
						|
	const auto requestId = GetNextRequestId();
 | 
						|
	const auto dcId = _dcId;
 | 
						|
	const auto msCanWait = _canWait;
 | 
						|
	const auto afterRequestId = _afterRequestId;
 | 
						|
 | 
						|
	_sender->senderRequestRegister(requestId, std::move(_handlers));
 | 
						|
	_sender->with_instance([
 | 
						|
		=,
 | 
						|
		request = std::move(_serialized),
 | 
						|
		done = std::make_shared<RPCDoneHandler>(_sender, _sender->_runner),
 | 
						|
		fail = std::make_shared<RPCFailHandler>(
 | 
						|
			_sender,
 | 
						|
			_sender->_runner,
 | 
						|
			_failSkipPolicy)
 | 
						|
	](not_null<Instance*> instance) mutable {
 | 
						|
		instance->sendSerialized(
 | 
						|
			requestId,
 | 
						|
			std::move(request),
 | 
						|
			RPCResponseHandler(std::move(done), std::move(fail)),
 | 
						|
			dcId,
 | 
						|
			msCanWait,
 | 
						|
			afterRequestId);
 | 
						|
	});
 | 
						|
 | 
						|
	return requestId;
 | 
						|
}
 | 
						|
 | 
						|
ConcurrentSender::ConcurrentSender(Fn<void(FnMut<void()>)> runner)
 | 
						|
: _runner(runner) {
 | 
						|
}
 | 
						|
 | 
						|
ConcurrentSender::~ConcurrentSender() {
 | 
						|
	senderRequestCancelAll();
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::senderRequestRegister(
 | 
						|
		mtpRequestId requestId,
 | 
						|
		Handlers &&handlers) {
 | 
						|
	_requests.emplace(requestId, std::move(handlers));
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::senderRequestDone(
 | 
						|
		mtpRequestId requestId,
 | 
						|
		bytes::const_span result) {
 | 
						|
	if (auto handlers = _requests.take(requestId)) {
 | 
						|
		try {
 | 
						|
			handlers->done(requestId, result);
 | 
						|
		} catch (Exception &e) {
 | 
						|
			handlers->fail(
 | 
						|
				requestId,
 | 
						|
				RPCError::Local(
 | 
						|
					"RESPONSE_PARSE_FAILED",
 | 
						|
					QString("exception text: ") + e.what()));
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::senderRequestFail(
 | 
						|
		mtpRequestId requestId,
 | 
						|
		RPCError &&error) {
 | 
						|
	if (auto handlers = _requests.take(requestId)) {
 | 
						|
		handlers->fail(requestId, std::move(error));
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::senderRequestCancel(mtpRequestId requestId) {
 | 
						|
	senderRequestDetach(requestId);
 | 
						|
	with_instance([=](not_null<Instance*> instance) {
 | 
						|
		instance->cancel(requestId);
 | 
						|
	});
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::senderRequestCancelAll() {
 | 
						|
	auto list = std::vector<mtpRequestId>(_requests.size());
 | 
						|
	for (const auto &pair : base::take(_requests)) {
 | 
						|
		list.push_back(pair.first);
 | 
						|
	}
 | 
						|
	with_instance([list = std::move(list)](not_null<Instance*> instance) {
 | 
						|
		for (const auto requestId : list) {
 | 
						|
			instance->cancel(requestId);
 | 
						|
		}
 | 
						|
	});
 | 
						|
}
 | 
						|
 | 
						|
void ConcurrentSender::senderRequestDetach(mtpRequestId requestId) {
 | 
						|
	_requests.erase(requestId);
 | 
						|
}
 | 
						|
 | 
						|
} // namespace MTP
 |