Changeset View
Changeset View
Standalone View
Standalone View
source/tools/rdb/src/main.rs
#![feature(stmt_expr_attributes, custom_attribute, unrestricted_attribute_tokens)] | |||||
#![allow(unused_attributes)] // So it doesn't bug me about sqlify-derive | |||||
#[macro_use] | |||||
extern crate num_derive; | |||||
#[macro_use] | |||||
extern crate diesel; | |||||
#[macro_use] | |||||
extern crate serde_derive; | |||||
#[macro_use] | |||||
extern crate lazy_static; | |||||
pub mod models; | |||||
pub mod schema; | |||||
mod account_actions; | |||||
mod replay_actions; | |||||
mod connection; | |||||
mod filebase; | |||||
mod ip_tracker; | |||||
use diesel::r2d2; | |||||
use native_tls::Identity; | |||||
use std::time::{Duration, Instant}; | |||||
use tokio::codec::Framed; | |||||
use tokio::net::{TcpListener, TcpStream}; | |||||
use tokio::prelude::*; | |||||
use tokio::timer::Interval; | |||||
use tokio_tls::TlsStream; | |||||
use std::fs::File; | |||||
use diesel::prelude::*; | |||||
use std::io::Cursor; | |||||
use flate2::Compression; | |||||
use flate2::bufread::GzDecoder; | |||||
use flate2::write::GzEncoder; | |||||
use std::io::{Read, Write}; | |||||
use tokio::codec::{Decoder, Encoder}; | |||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; | |||||
use bytes::{BufMut, BytesMut}; | |||||
use self::account_actions::AccountActions; | |||||
use self::ip_tracker::IPTracker; | |||||
use self::replay_actions::ReplayActions; | |||||
use self::connection::{Msg, ConnState}; | |||||
use self::models::BanType; | |||||
pub type DbConn = r2d2::PooledConnection<r2d2::ConnectionManager<MysqlConnection>>; | |||||
#[derive(Copy, Clone, Debug)] | |||||
pub struct MsgCodec(()); | |||||
impl MsgCodec { | |||||
pub fn new() -> MsgCodec { | |||||
MsgCodec(()) | |||||
} | |||||
} | |||||
impl Decoder for MsgCodec { | |||||
type Item = Msg; | |||||
type Error = std::io::Error; | |||||
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { | |||||
if buf.len() < 8 { | |||||
return Ok(None); | |||||
} | |||||
let msg_len = Cursor::new(buf.get(0..8).unwrap()) | |||||
.read_u64::<LittleEndian>() | |||||
.unwrap() as usize; | |||||
if buf.len() < 16 + msg_len { | |||||
return Ok(None); | |||||
} | |||||
buf.advance(8); | |||||
let t_len = Cursor::new(buf.split_to(8)) | |||||
.read_u64::<LittleEndian>() | |||||
.unwrap() as usize; | |||||
let data = buf.split_to(msg_len).clone(); | |||||
let mut gz = GzDecoder::new(&data[..]); | |||||
let mut msg = String::new(); | |||||
gz.read_to_string(&mut msg).unwrap(); | |||||
assert_eq!(msg.len(), t_len); | |||||
Ok(Some(serde_json::from_str(&msg).unwrap())) | |||||
} | |||||
} | |||||
impl Encoder for MsgCodec { | |||||
type Item = Msg; | |||||
type Error = std::io::Error; | |||||
fn encode(&mut self, data: Self::Item, buf: &mut BytesMut) -> Result<(), Self::Error> { | |||||
let send = serde_json::to_string(&data).unwrap(); | |||||
let send = send.as_bytes(); | |||||
let send_len = send.len(); | |||||
let mut gz = GzEncoder::new(Vec::new(), Compression::fast()); | |||||
gz.write_all(send).unwrap(); | |||||
let send: Vec<u8> = gz.finish().unwrap(); | |||||
buf.reserve(send.len() + 16); | |||||
let mut size = vec![]; | |||||
size.write_u64::<LittleEndian>(send.len() as u64).unwrap(); | |||||
size.write_u64::<LittleEndian>(send_len as u64).unwrap(); | |||||
buf.put(size); | |||||
buf.put(send); | |||||
Ok(()) | |||||
} | |||||
} | |||||
type ActionRetType = Result<(), std::io::Error>; | |||||
type WriterType = stream::SplitSink<Framed<TlsStream<TcpStream>, MsgCodec>>; | |||||
fn handle_data( | |||||
conn_state: &mut ConnState, | |||||
writer: &mut stream::SplitSink<Framed<TlsStream<TcpStream>, MsgCodec>>, | |||||
data: Msg, | |||||
db_conn: DbConn, | |||||
) -> Result<(), std::io::Error> { | |||||
match data { | |||||
Msg::Register { .. } if conn_state.username.is_none() => { | |||||
AccountActions::handle_register(conn_state, writer, &data, &db_conn) | |||||
} | |||||
Msg::Login { .. } if conn_state.username.is_none() => { | |||||
AccountActions::handle_login(conn_state, writer, &data, &db_conn) | |||||
} | |||||
_ if conn_state.username.is_none() => { | |||||
writer.start_send(Msg::Error { | |||||
msg: format!("Access denied."), | |||||
})?; | |||||
Ok(()) | |||||
} | |||||
Msg::QueryForUserData { .. } => { | |||||
AccountActions::handle_query_for_user_data(conn_state, writer, &data, &db_conn) | |||||
} | |||||
Msg::QueryForReplayDatas { .. } => { | |||||
ReplayActions::handle_query_for_replay_datas(conn_state, writer, &data, &db_conn) | |||||
} | |||||
Msg::QueryForReplayList { .. } => { | |||||
ReplayActions::handle_query_for_replay_list(conn_state, writer, &data, &db_conn) | |||||
} | |||||
_ => panic!(), | |||||
} | |||||
} | |||||
fn main() { | |||||
dotenv::dotenv().ok(); | |||||
tokio::run(future::ok(()).and_then(|_| server_main())); | |||||
} | |||||
fn server_main() -> Result<(), ()> { | |||||
let db_url = dotenv::var("DATABASE_URL").unwrap(); | |||||
let db_manager = r2d2::ConnectionManager::<MysqlConnection>::new(db_url); | |||||
let db_pool = r2d2::Pool::builder() | |||||
.max_size(50) | |||||
.build(db_manager) | |||||
.unwrap(); | |||||
// Bind the server's socket | |||||
let addr = "127.0.0.1:16180".parse().unwrap(); | |||||
let tcp = TcpListener::bind(&addr).unwrap(); | |||||
// Create the TLS acceptor. | |||||
let mut der = vec![]; | |||||
File::open(dotenv::var("TLS_P12_CERTIFICATE").unwrap()) | |||||
.unwrap() | |||||
.read_to_end(&mut der) | |||||
.unwrap(); | |||||
let cert = Identity::from_pkcs12(&der, &dotenv::var("TLS_PASSWORD").unwrap()).unwrap(); | |||||
let tls_acceptor = | |||||
tokio_tls::TlsAcceptor::from(native_tls::TlsAcceptor::builder(cert).build().unwrap()); | |||||
{ | |||||
let db_pool = db_pool.clone(); | |||||
let clean_ips_task = | |||||
Interval::new(Instant::now(), Duration::from_secs(60 * 60 * 24)) // once an day | |||||
.map_err(|e| panic!("Clean ips timer failed; err={:?}", e)) | |||||
.for_each(move |instant| { | |||||
println!("Started clean ips at {:?}", instant); | |||||
IPTracker::clean_ips(db_pool.get().unwrap()) | |||||
}); | |||||
tokio::spawn(clean_ips_task); | |||||
} | |||||
{ | |||||
let db_pool = db_pool.clone(); | |||||
let server = tcp | |||||
.incoming() | |||||
.map_err(|e| println!("Failed to accept socket; error = {:?}", e)) | |||||
.for_each(move |tcp| { | |||||
let db_pool = db_pool.clone(); | |||||
let mut conn_state = ConnState::new(tcp.peer_addr().unwrap()); | |||||
let tls_accept = tls_acceptor | |||||
.accept(tcp) | |||||
.and_then(|tls| { | |||||
let framed = MsgCodec::new().framed(tls); | |||||
let (mut writer, reader) = framed.split(); | |||||
let processor = reader | |||||
.for_each(move |data| { | |||||
let db_conn = db_pool.get().unwrap(); | |||||
if let Some(msg) = IPTracker::is_banned(&db_conn, &conn_state, BanType::Connect) { | |||||
writer.start_send(Msg::Error { msg })?; | |||||
writer.poll_complete()?; | |||||
return Ok(()); | |||||
} | |||||
let ret = handle_data( | |||||
&mut conn_state, | |||||
&mut writer, | |||||
data, | |||||
db_conn, | |||||
); | |||||
writer.poll_complete()?; | |||||
ret | |||||
}) | |||||
.and_then(|()| { | |||||
println!("Socket received FIN packet and closed connection"); | |||||
Ok(()) | |||||
}) | |||||
.or_else(|err| { | |||||
println!("Socket closed with error: {:?}", err); | |||||
Err(err) | |||||
}) | |||||
.then(|result| { | |||||
println!("Socket closed with result: {:?}", result); | |||||
Ok(()) | |||||
}); | |||||
tokio::spawn(processor); | |||||
Ok(()) | |||||
}) | |||||
.map_err(|err| { | |||||
println!("TLS accept error: {:?}", err); | |||||
}); | |||||
tokio::spawn(tls_accept); | |||||
Ok(()) | |||||
}) | |||||
.map_err(|err| { | |||||
println!("server error {:?}", err); | |||||
}); | |||||
tokio::spawn(server); | |||||
} | |||||
Ok(()) | |||||
} |
Wildfire Games · Phabricator