improve server relay and client
This commit is contained in:
parent
9984583dd2
commit
fd0e2afddb
6 changed files with 461 additions and 109 deletions
|
|
@ -1,34 +1,62 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"gossip/pkg/protocol"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
db *sql.DB
|
||||
clients map[string]*websocket.Conn
|
||||
mu sync.Mutex
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := godotenv.Load(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
dbuser := os.Getenv("DBUSER")
|
||||
dbname := os.Getenv("DBNAME")
|
||||
dbport := os.Getenv("DBPORT")
|
||||
dbhost := os.Getenv("DBHOST")
|
||||
dbpass := os.Getenv("DBPASS")
|
||||
|
||||
connStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=require", dbuser, dbpass, dbhost, dbport, dbname)
|
||||
|
||||
db, err := sql.Open("pgx", connStr)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to open DB:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Fatal("Failed to connect to DB:", err)
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
db: db,
|
||||
clients: make(map[string]*websocket.Conn),
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
},
|
||||
}
|
||||
|
||||
http.HandleFunc("/ws", srv.handleWS)
|
||||
log.Println("Relay Server listening on :8080")
|
||||
log.Println("Gossip Relay Server listening on :8080")
|
||||
log.Fatal(http.ListenAndServe(":8080", nil))
|
||||
}
|
||||
|
||||
|
|
@ -54,22 +82,47 @@ func (s *Server) handleWS(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
switch msg.Type {
|
||||
case "login":
|
||||
s.mu.Lock()
|
||||
s.clients[msg.Sender] = conn
|
||||
s.mu.Unlock()
|
||||
case protocol.TypeLogin:
|
||||
myPubKey = msg.Sender
|
||||
log.Printf("Client connected: %s...", myPubKey[:8])
|
||||
case "msg":
|
||||
|
||||
s.mu.Lock()
|
||||
s.clients[myPubKey] = conn
|
||||
s.mu.Unlock()
|
||||
|
||||
accNum, err := s.getOrCreateUser(myPubKey)
|
||||
if err != nil {
|
||||
log.Println("DB Error:", err)
|
||||
continue
|
||||
}
|
||||
|
||||
resp := protocol.Message{
|
||||
Type: protocol.TypeIdentity,
|
||||
Target: myPubKey,
|
||||
Content: strconv.Itoa(accNum),
|
||||
}
|
||||
conn.WriteJSON(resp)
|
||||
log.Printf("User %d connected (%s...)", accNum, myPubKey[:8])
|
||||
|
||||
case protocol.TypeLookup:
|
||||
targetAcc, _ := strconv.Atoi(msg.Content)
|
||||
foundKey, ok := s.lookupKey(targetAcc)
|
||||
|
||||
resp := protocol.Message{
|
||||
Type: protocol.TypeLookupResponse,
|
||||
Target: msg.Content,
|
||||
}
|
||||
if ok {
|
||||
resp.Content = foundKey
|
||||
}
|
||||
conn.WriteJSON(resp)
|
||||
|
||||
case protocol.TypeMsg:
|
||||
s.mu.Lock()
|
||||
targetConn, ok := s.clients[msg.Target]
|
||||
s.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
err = targetConn.WriteMessage(websocket.TextMessage, data)
|
||||
if err != nil {
|
||||
log.Printf("Failed to relay to %s", msg.Target[:8])
|
||||
}
|
||||
targetConn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -78,6 +131,24 @@ func (s *Server) handleWS(w http.ResponseWriter, r *http.Request) {
|
|||
s.mu.Lock()
|
||||
delete(s.clients, myPubKey)
|
||||
s.mu.Unlock()
|
||||
log.Printf("Client disconnected: %s...", myPubKey[:8])
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) getOrCreateUser(pubKey string) (int, error) {
|
||||
var accNum int
|
||||
err := s.db.QueryRow("SELECT account_number FROM users WHERE public_key=$1", pubKey).Scan(&accNum)
|
||||
if err == nil {
|
||||
return accNum, nil
|
||||
}
|
||||
err = s.db.QueryRow("INSERT INTO users (public_key) VALUES ($1) RETURNING account_number", pubKey).Scan(&accNum)
|
||||
return accNum, err
|
||||
}
|
||||
|
||||
func (s *Server) lookupKey(accNum int) (string, bool) {
|
||||
var pubKey string
|
||||
err := s.db.QueryRow("SELECT public_key FROM users WHERE account_number=$1", accNum).Scan(&pubKey)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return pubKey, true
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue