Refactoring move methods to separate packages add optional arguments

This commit is contained in:
Michal Szczepanski 2019-09-13 02:32:41 +02:00
parent 273e2fe3c5
commit e09b821c37
8 changed files with 192 additions and 104 deletions

2
.gitignore vendored

@ -1 +1,3 @@
*.iml *.iml
*.log
main

@ -160,7 +160,8 @@ block:
- '.*apenterprise.io.*' - '.*apenterprise.io.*'
- '.*.netdna-ssl.com.*' - '.*.netdna-ssl.com.*'
- '.*.demandbase.com.*' - '.*.demandbase.com.*'
- '.*.wp.com.*' - '.stats.wp.com.*'
- '.pixel.wp.com.*'
- '.*.clicktale.net.*' - '.*.clicktale.net.*'
- '.*.report-uri.com.*' - '.*.report-uri.com.*'
- '.*.algolia.com.*' - '.*.algolia.com.*'

136
main.go

@ -1,119 +1,50 @@
package main package main
import ( import (
"./proxy"
"crypto/tls" "crypto/tls"
"gopkg.in/yaml.v2"
"fmt" "fmt"
"io" "github.com/akamensky/argparse"
"io/ioutil"
"log"
"net"
"net/http" "net/http"
"regexp" "os"
"time"
) )
/*Based on
https://github.com/bechurch/reverse-proxy-demo
https://github.com/txn2/p3y/blob/master/p3y.go
https://medium.com/@mlowicki/http-s-proxy-in-golang-in-less-than-100-lines-of-code-6a51c2f2c38c
*/
type conf struct { func readArgs() (string, string, string) {
Entries []string `yaml:"block"` parser := argparse.NewParser("blocking-http-proxy", "HTTP/S proxy that blocks")
} host := parser.String("", "host", &argparse.Options{Help: "host", Required: false})
port := parser.String("", "port", &argparse.Options{Help: "port", Required: false})
//https blockFile := parser.String("", "block", &argparse.Options{Help: "YAML block file", Required: false})
func handleTunneling(res http.ResponseWriter, req *http.Request) { err := parser.Parse(os.Args)
dest_conn, err := net.DialTimeout("tcp", req.Host, 10*time.Second)
if err != nil { if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable) fmt.Println(parser.Usage(err))
return
} }
res.WriteHeader(http.StatusOK) if *host == "" {
hijacker, ok := res.(http.Hijacker) *host = "0.0.0.0"
if !ok {
http.Error(res, "Hijacking not supported", http.StatusInternalServerError)
return
} }
client_conn, _, err := hijacker.Hijack() if *port == "" {
if err != nil { *port = "11666"
http.Error(res, err.Error(), http.StatusServiceUnavailable)
} }
go transfer(dest_conn, client_conn) if *blockFile == "" {
go transfer(client_conn, dest_conn) *blockFile = "block.yaml"
}
func transfer(destination io.WriteCloser, source io.ReadCloser) {
defer destination.Close()
defer source.Close()
io.Copy(destination, source)
}
// http
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
} }
} return *host, *port, *blockFile
}
func handleRequestCustom(res http.ResponseWriter, req *http.Request) {
transport := http.DefaultTransport
out, err := transport.RoundTrip(req)
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
return
}
copyHeader(res.Header(), out.Header)
res.WriteHeader(out.StatusCode)
_, err = io.Copy(res, out.Body)
if err != nil {
http.Error(res, err.Error(), http.StatusInternalServerError)
return
}
err = out.Body.Close()
if err != nil {
http.Error(res, err.Error(), http.StatusInternalServerError)
return
}
}
func loadBlockedList(filename string) []regexp.Regexp {
var c conf
yamlFile, err := ioutil.ReadFile(filename)
if err != nil {
log.Printf("yamlFile.Get err #%v ", err)
}
err = yaml.Unmarshal(yamlFile, &c)
regexps := []regexp.Regexp{}
for _, condition := range c.Entries {
//log.Printf("%s", condition)
r := regexp.MustCompile(condition)
regexps = append(regexps, *r)
}
return regexps
}
func shouldBlock(regList []regexp.Regexp, url string) bool {
for _, condition := range regList {
if condition.MatchString(url) {
return true
}
}
return false
} }
func main() { func main() {
port := "0.0.0.0:11666" logger := proxy.NewLogger()
log.Printf("Server will run on: %s\n", port) blockedLogger := proxy.NewFileLogger("block.log")
http.HandleFunc("/", handleRequestCustom) allowLogger := proxy.NewFileLogger("allow.log")
regList := loadBlockedList("block.yaml") host, port, blockFile := readArgs()
address := host+":"+port
c := proxy.NewConfig()
fmt.Println(len(os.Args), os.Args)
logger.Printf("Server will run on: %s\n", address)
regList := c.LoadBlockedList(blockFile)
server := &http.Server{ server := &http.Server{
Addr: port, Addr: address,
Handler: http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { Handler: http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
if shouldBlock(regList, req.Host) { if proxy.ShouldBlock(regList, req.Host) {
//log.Printf("Blocked %s\n", req.Host) blockedLogger.WriteLine("Blocked %s\n", false, req.Host)
if wr, ok := res.(http.Hijacker); ok { if wr, ok := res.(http.Hijacker); ok {
conn, _, err := wr.Hijack() conn, _, err := wr.Hijack()
if err != nil { if err != nil {
@ -123,11 +54,11 @@ func main() {
} }
} else { } else {
if req.Method == http.MethodConnect { if req.Method == http.MethodConnect {
log.Printf("proxy_url: %s\n", req.Host) allowLogger.WriteLine("proxy_url: %s\n", true, req.Host)
handleTunneling(res, req) proxy.HandleTunneling(res, req)
} else { } else {
log.Printf("proxy_url: %s%s\n", req.Host, req.RequestURI) allowLogger.WriteLine("proxy_url: %s %s\n", true, req.Host, req.RequestURI)
handleRequestCustom(res, req) proxy.HandleRequestCustom(res, req)
} }
} }
@ -135,6 +66,5 @@ func main() {
// Disable HTTP/2. // Disable HTTP/2.
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
} }
server.ListenAndServe() server.ListenAndServe()
} }

12
proxy/block.go Normal file

@ -0,0 +1,12 @@
package proxy
import "regexp"
func ShouldBlock(regList []regexp.Regexp, url string) bool {
for _, condition := range regList {
if condition.MatchString(url) {
return true
}
}
return false
}

31
proxy/config.go Normal file

@ -0,0 +1,31 @@
package proxy
import (
"io/ioutil"
"gopkg.in/yaml.v2"
"regexp"
)
type Config struct {
Entries []string `yaml:"block"`
log Logger
}
func NewConfig() *Config {
return &Config{Entries: nil, log:Logger{}}
}
func (c *Config) LoadBlockedList(filename string) []regexp.Regexp {
yamlFile, err := ioutil.ReadFile(filename)
if err != nil {
c.log.Printf("yamlFile.Get err #%v ", err)
}
err = yaml.Unmarshal(yamlFile, &c)
regexps := []regexp.Regexp{}
for _, condition := range c.Entries {
//log.Printf("%s", condition)
r := regexp.MustCompile(condition)
regexps = append(regexps, *r)
}
return regexps
}

35
proxy/http.go Normal file

@ -0,0 +1,35 @@
package proxy
import (
"io"
"net/http"
)
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func HandleRequestCustom(res http.ResponseWriter, req *http.Request) {
transport := http.DefaultTransport
out, err := transport.RoundTrip(req)
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
return
}
copyHeader(res.Header(), out.Header)
res.WriteHeader(out.StatusCode)
_, err = io.Copy(res, out.Body)
if err != nil {
http.Error(res, err.Error(), http.StatusInternalServerError)
return
}
err = out.Body.Close()
if err != nil {
http.Error(res, err.Error(), http.StatusInternalServerError)
return
}
}

33
proxy/https.go Normal file

@ -0,0 +1,33 @@
package proxy
import (
"io"
"net"
"net/http"
"time"
)
func HandleTunneling(res http.ResponseWriter, req *http.Request) {
dest_conn, err := net.DialTimeout("tcp", req.Host, 10*time.Second)
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
return
}
res.WriteHeader(http.StatusOK)
hijacker, ok := res.(http.Hijacker)
if !ok {
http.Error(res, "Hijacking not supported", http.StatusInternalServerError)
return
}
client_conn, _, err := hijacker.Hijack()
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
}
go transfer(dest_conn, client_conn)
go transfer(client_conn, dest_conn)
}
func transfer(destination io.WriteCloser, source io.ReadCloser) {
defer destination.Close()
defer source.Close()
io.Copy(destination, source)
}

44
proxy/logger.go Normal file

@ -0,0 +1,44 @@
package proxy
import (
"fmt"
"log"
"os"
)
type Logger struct{
}
type FileLogger struct {
Logger
filename string
}
func NewLogger() *Logger {
return &Logger{}
}
func NewFileLogger(filename string) *FileLogger {
if _, err := os.Stat(filename); os.IsNotExist(err) {
os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
}
return &FileLogger{filename: filename}
}
func (l *FileLogger) WriteLine(msg string, print bool, v ...interface{}) {
f, err := os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
panic(err)
}
line := fmt.Sprintf(msg, v...)
if print {
log.Printf(msg, v...)
}
f.WriteString(line)
defer f.Close()
}
func (l *Logger) Printf(msg string, v ...interface{}) {
log.Printf(msg, v...)
}