OSDN Git Service

64bit用に修正
[mmo/main.git] / common / network / Session.cpp
1 //
2 // Session.cpp
3 //
4
5 #include "Command.hpp"
6 #include "CommandHeader.hpp"
7 #include "Session.hpp"
8 #include "Utils.hpp"
9 #include "../Logger.hpp"
10 #include <boost/make_shared.hpp>
11 #include <string>
12
13 namespace network {
14
15     Session::Session(boost::asio::io_service& io_service_tcp) :
16       io_service_tcp_(io_service_tcp),
17       socket_tcp_(io_service_tcp),
18       encryption_(false),
19       online_(true),
20       login_(false),
21       read_start_time_(time(nullptr)),
22       write_start_time_(time(nullptr)),
23       read_byte_sum_(0),
24       write_byte_sum_(0),
25       serialized_byte_sum_(0),
26       compressed_byte_sum_(0),
27       id_(0)
28     {
29
30     }
31
32     Session::~Session()
33     {
34         Close();
35     }
36
37     void Session::Close()
38     {
39         socket_tcp_.close();
40     }
41
42     void Session::Send(const Command& command)
43     {
44         auto msg = Serialize(command);
45         write_byte_sum_ += msg.size();
46         UpdateWriteByteAverage();
47
48         io_service_tcp_.post(boost::bind(&Session::DoWriteTCP, this, msg, shared_from_this()));
49     }
50
51     void Session::SyncSend(const Command& command)
52     {
53         auto msg = Serialize(command);
54         write_byte_sum_ += msg.size();
55         UpdateWriteByteAverage();
56
57         try {
58             boost::asio::write(
59                     socket_tcp_, boost::asio::buffer(msg.data(), msg.size()),
60                 boost::asio::transfer_all());
61         } catch (std::exception& e) {
62             std::cout << e.what() << std::endl;
63         }
64     }
65
66     double Session::GetReadByteAverage() const
67     {
68         return 1.0f * read_byte_sum_ / (time(nullptr) - read_start_time_);
69     }
70
71     double Session::GetWriteByteAverage() const
72     {
73         return 1.0f * write_byte_sum_ / (time(nullptr) - write_start_time_);
74     }
75
76     void Session::UpdateReadByteAverage()
77     {
78         unsigned long elapsed_time = time(nullptr) - read_start_time_;
79         if (elapsed_time >= BYTE_AVERAGE_REFRESH_SECONDS) {
80             read_byte_sum_ /= 2;
81             read_start_time_ = time(nullptr) - elapsed_time / 2;
82         }
83     }
84
85     void Session::UpdateWriteByteAverage()
86     {
87         unsigned long elapsed_time = time(nullptr) - write_start_time_;
88         if (elapsed_time >= BYTE_AVERAGE_REFRESH_SECONDS) {
89             write_byte_sum_ /= 2;
90             write_start_time_ = time(nullptr) - elapsed_time / 2;
91         }
92     }
93
94     void Session::EnableEncryption()
95     {
96         encryption_ = true;
97     }
98
99     Encrypter& Session::encrypter()
100     {
101         return encrypter_;
102     }
103
104     tcp::socket& Session::tcp_socket()
105     {
106         return socket_tcp_;
107     }
108
109     UserID Session::id() const
110     {
111         return id_;
112     }
113
114     void Session::set_id(UserID id)
115     {
116         id_ = id;
117     }
118
119     bool Session::online() const
120     {
121         return online_;
122     }
123
124     std::string Session::global_ip() const
125     {
126         return global_ip_;
127     }
128
129     uint16_t Session::udp_port() const{
130         return udp_port_;
131     }
132
133     void Session::set_global_ip(const std::string& global_ip)
134     {
135         global_ip_ = global_ip;
136     }
137
138     void Session::set_udp_port(uint16_t udp_port)
139     {
140         udp_port_ = udp_port;
141     }
142
143     int Session::serialized_byte_sum() const
144     {
145         return serialized_byte_sum_;
146     }
147
148     int Session::compressed_byte_sum() const
149     {
150         return compressed_byte_sum_;
151     }
152
153     bool Session::operator==(const Session& s)
154     {
155         return id_ == s.id_;
156     }
157
158     bool Session::operator!=(const Session& s)
159     {
160         return !operator==(s);
161     }
162
163     std::string Session::Serialize(const Command& command)
164     {
165         assert(command.header() < 0xFF);
166         auto header = static_cast<uint8_t>(command.header());\r
167         std::string body = command.body();
168
169         std::string msg = Utils::Serialize(header) + body;
170
171         // 圧縮
172         if (body.size() >= COMPRESS_MIN_LENGTH) {
173             auto compressed = Utils::LZ4Compress(msg);
174             if (msg.size() > compressed.size() + sizeof(uint8_t)) {\r
175                 assert(msg.size() < 65535);
176                 msg = Utils::Serialize(static_cast<uint8_t>(header::LZ4_COMPRESS_HEADER),\r
177                     static_cast<uint16_t>(msg.size()))\r
178                     + compressed;
179             }
180         }
181
182         // 暗号化
183         if (encryption_) {
184             msg = Utils::Serialize(static_cast<uint8_t>(header::ENCRYPT_HEADER))\r
185                 + encrypter_.Encrypt(msg);
186         }
187
188         return Utils::Encode(msg);
189     }
190
191     Command Session::Deserialize(const std::string& msg)
192     {
193         std::string decoded_msg = Utils::Decode(msg);
194
195         uint8_t header;\r
196         Utils::Deserialize(decoded_msg, &header);
197
198         // 復号
199         if (header == header::ENCRYPT_HEADER) {
200             decoded_msg.erase(0, sizeof(header));
201             decoded_msg = encrypter_.Decrypt(decoded_msg);
202             Utils::Deserialize(decoded_msg, &header);
203         }
204
205         // 伸長
206         if (header == header::LZ4_COMPRESS_HEADER) {
207             uint16_t original_size;\r
208             Utils::Deserialize(decoded_msg, &header, &original_size);
209             decoded_msg.erase(0, sizeof(header) + sizeof(original_size));
210             decoded_msg = Utils::LZ4Uncompress(decoded_msg, original_size);
211             Utils::Deserialize(decoded_msg, &header);
212         }
213
214         std::string body = decoded_msg.substr(sizeof(header));
215
216         return Command(static_cast<header::CommandHeader>(header), body, shared_from_this());
217     }
218
219     void Session::ReceiveTCP(const boost::system::error_code& error)
220     {
221         if (!error) {
222             std::string buffer(boost::asio::buffer_cast<const char*>(receive_buf_.data()),receive_buf_.size());
223             auto length = buffer.find_last_of(NETWORK_UTILS_DELIMITOR);
224
225             if (length != std::string::npos) {
226
227                 receive_buf_.consume(length+1);
228                 buffer.erase(length+1);
229
230                 while (!buffer.empty()) {
231                     std::string msg;
232
233                     while (!buffer.empty() && buffer[0]!=NETWORK_UTILS_DELIMITOR)
234                     {
235                         msg += buffer[0];
236                         buffer.erase(0,1);
237                     }
238                     buffer.erase(0,1);
239
240                     read_byte_sum_ += msg.size();
241                     UpdateReadByteAverage();
242
243                     FetchTCP(msg);
244                 }
245
246                 boost::asio::async_read_until(socket_tcp_,
247                     receive_buf_, NETWORK_UTILS_DELIMITOR,
248                     boost::bind(
249                       &Session::ReceiveTCP, shared_from_this(),
250                       boost::asio::placeholders::error));
251
252             }
253
254         } else {
255             FatalError();
256         }
257     }
258
259     void Session::DoWriteTCP(const std::string msg, SessionPtr session_holder)
260     {
261         bool write_in_progress = !send_queue_.empty();
262         send_queue_.push(msg);
263         if (!write_in_progress && !send_queue_.empty())
264         {
265            
266           boost::shared_ptr<std::string> s = 
267               boost::make_shared<std::string>(msg.data(), msg.size());
268
269           boost::asio::async_write(socket_tcp_,
270               boost::asio::buffer(s->data(), s->size()),
271               boost::bind(&Session::WriteTCP, this,
272                 boost::asio::placeholders::error, s, session_holder));
273         }
274     }
275
276     void Session::WriteTCP(const boost::system::error_code& error,
277                 boost::shared_ptr<std::string> holder, SessionPtr session_holder)
278     {
279         if (!error) {
280             if (!send_queue_.empty()) {
281                   send_queue_.pop();
282                   if (!send_queue_.empty())
283                   {
284
285                     boost::shared_ptr<std::string> s = 
286                         boost::make_shared<std::string>(send_queue_.front().data(), send_queue_.front().size());
287
288                     boost::asio::async_write(socket_tcp_,
289                         boost::asio::buffer(s->data(), s->size()),
290                         boost::bind(&Session::WriteTCP, this,
291                           boost::asio::placeholders::error, s, session_holder));
292                   }
293             }
294         } else {
295             FatalError(session_holder);
296         }
297     }
298
299     void Session::FetchTCP(const std::string& msg)
300     {
301         if (msg.size() >= sizeof(uint8_t)) {\r
302             if (on_receive_) {
303                 (*on_receive_)(Deserialize(msg));
304             }
305         } else {
306             Logger::Error(_T("Too short data"));
307         }
308     }
309
310     void Session::FatalError(SessionPtr session_holder)
311     {
312         if (online_) {
313             online_ = false;
314             if (on_receive_) {
315                 if (id_ > 0) {
316                     (*on_receive_)(FatalConnectionError(id_));
317                 } else {
318                     (*on_receive_)(FatalConnectionError());
319                 }
320             }
321         }
322     }
323
324     void Session::set_on_receive(CallbackFuncPtr func)
325     {
326         on_receive_ = func;
327     }
328
329
330 }