Refactoring move methods to separate packages add optional arguments

This commit is contained in:
Michal Szczepanski 2019-09-13 02:32:41 +02:00
parent 273e2fe3c5
commit e09b821c37
8 changed files with 192 additions and 104 deletions

2
.gitignore vendored
View File

@ -1 +1,3 @@
*.iml
*.log
main

View File

@ -160,7 +160,8 @@ block:
- '.*apenterprise.io.*'
- '.*.netdna-ssl.com.*'
- '.*.demandbase.com.*'
- '.*.wp.com.*'
- '.stats.wp.com.*'
- '.pixel.wp.com.*'
- '.*.clicktale.net.*'
- '.*.report-uri.com.*'
- '.*.algolia.com.*'

136
main.go
View File

@ -1,119 +1,50 @@
package main
import (
"./proxy"
"crypto/tls"
"gopkg.in/yaml.v2"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"github.com/akamensky/argparse"
"net/http"
"regexp"
"time"
"os"
)
/*Based on
https://github.com/bechurch/reverse-proxy-demo
https://github.com/txn2/p3y/blob/master/p3y.go
https://medium.com/@mlowicki/http-s-proxy-in-golang-in-less-than-100-lines-of-code-6a51c2f2c38c
*/
type conf struct {
Entries []string `yaml:"block"`
}
//https
func handleTunneling(res http.ResponseWriter, req *http.Request) {
dest_conn, err := net.DialTimeout("tcp", req.Host, 10*time.Second)
func readArgs() (string, string, string) {
parser := argparse.NewParser("blocking-http-proxy", "HTTP/S proxy that blocks")
host := parser.String("", "host", &argparse.Options{Help: "host", Required: false})
port := parser.String("", "port", &argparse.Options{Help: "port", Required: false})
blockFile := parser.String("", "block", &argparse.Options{Help: "YAML block file", Required: false})
err := parser.Parse(os.Args)
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
return
fmt.Println(parser.Usage(err))
}
res.WriteHeader(http.StatusOK)
hijacker, ok := res.(http.Hijacker)
if !ok {
http.Error(res, "Hijacking not supported", http.StatusInternalServerError)
return
if *host == "" {
*host = "0.0.0.0"
}
client_conn, _, err := hijacker.Hijack()
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
if *port == "" {
*port = "11666"
}
go transfer(dest_conn, client_conn)
go transfer(client_conn, dest_conn)
}
func transfer(destination io.WriteCloser, source io.ReadCloser) {
defer destination.Close()
defer source.Close()
io.Copy(destination, source)
}
// http
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
if *blockFile == "" {
*blockFile = "block.yaml"
}
}
func handleRequestCustom(res http.ResponseWriter, req *http.Request) {
transport := http.DefaultTransport
out, err := transport.RoundTrip(req)
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
return
}
copyHeader(res.Header(), out.Header)
res.WriteHeader(out.StatusCode)
_, err = io.Copy(res, out.Body)
if err != nil {
http.Error(res, err.Error(), http.StatusInternalServerError)
return
}
err = out.Body.Close()
if err != nil {
http.Error(res, err.Error(), http.StatusInternalServerError)
return
}
}
func loadBlockedList(filename string) []regexp.Regexp {
var c conf
yamlFile, err := ioutil.ReadFile(filename)
if err != nil {
log.Printf("yamlFile.Get err #%v ", err)
}
err = yaml.Unmarshal(yamlFile, &c)
regexps := []regexp.Regexp{}
for _, condition := range c.Entries {
//log.Printf("%s", condition)
r := regexp.MustCompile(condition)
regexps = append(regexps, *r)
}
return regexps
}
func shouldBlock(regList []regexp.Regexp, url string) bool {
for _, condition := range regList {
if condition.MatchString(url) {
return true
}
}
return false
return *host, *port, *blockFile
}
func main() {
port := "0.0.0.0:11666"
log.Printf("Server will run on: %s\n", port)
http.HandleFunc("/", handleRequestCustom)
regList := loadBlockedList("block.yaml")
logger := proxy.NewLogger()
blockedLogger := proxy.NewFileLogger("block.log")
allowLogger := proxy.NewFileLogger("allow.log")
host, port, blockFile := readArgs()
address := host+":"+port
c := proxy.NewConfig()
fmt.Println(len(os.Args), os.Args)
logger.Printf("Server will run on: %s\n", address)
regList := c.LoadBlockedList(blockFile)
server := &http.Server{
Addr: port,
Addr: address,
Handler: http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
if shouldBlock(regList, req.Host) {
//log.Printf("Blocked %s\n", req.Host)
if proxy.ShouldBlock(regList, req.Host) {
blockedLogger.WriteLine("Blocked %s\n", false, req.Host)
if wr, ok := res.(http.Hijacker); ok {
conn, _, err := wr.Hijack()
if err != nil {
@ -123,11 +54,11 @@ func main() {
}
} else {
if req.Method == http.MethodConnect {
log.Printf("proxy_url: %s\n", req.Host)
handleTunneling(res, req)
allowLogger.WriteLine("proxy_url: %s\n", true, req.Host)
proxy.HandleTunneling(res, req)
} else {
log.Printf("proxy_url: %s%s\n", req.Host, req.RequestURI)
handleRequestCustom(res, req)
allowLogger.WriteLine("proxy_url: %s %s\n", true, req.Host, req.RequestURI)
proxy.HandleRequestCustom(res, req)
}
}
@ -135,6 +66,5 @@ func main() {
// Disable HTTP/2.
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
}
server.ListenAndServe()
}

12
proxy/block.go Normal file
View File

@ -0,0 +1,12 @@
package proxy
import "regexp"
func ShouldBlock(regList []regexp.Regexp, url string) bool {
for _, condition := range regList {
if condition.MatchString(url) {
return true
}
}
return false
}

31
proxy/config.go Normal file
View File

@ -0,0 +1,31 @@
package proxy
import (
"io/ioutil"
"gopkg.in/yaml.v2"
"regexp"
)
type Config struct {
Entries []string `yaml:"block"`
log Logger
}
func NewConfig() *Config {
return &Config{Entries: nil, log:Logger{}}
}
func (c *Config) LoadBlockedList(filename string) []regexp.Regexp {
yamlFile, err := ioutil.ReadFile(filename)
if err != nil {
c.log.Printf("yamlFile.Get err #%v ", err)
}
err = yaml.Unmarshal(yamlFile, &c)
regexps := []regexp.Regexp{}
for _, condition := range c.Entries {
//log.Printf("%s", condition)
r := regexp.MustCompile(condition)
regexps = append(regexps, *r)
}
return regexps
}

35
proxy/http.go Normal file
View File

@ -0,0 +1,35 @@
package proxy
import (
"io"
"net/http"
)
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func HandleRequestCustom(res http.ResponseWriter, req *http.Request) {
transport := http.DefaultTransport
out, err := transport.RoundTrip(req)
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
return
}
copyHeader(res.Header(), out.Header)
res.WriteHeader(out.StatusCode)
_, err = io.Copy(res, out.Body)
if err != nil {
http.Error(res, err.Error(), http.StatusInternalServerError)
return
}
err = out.Body.Close()
if err != nil {
http.Error(res, err.Error(), http.StatusInternalServerError)
return
}
}

33
proxy/https.go Normal file
View File

@ -0,0 +1,33 @@
package proxy
import (
"io"
"net"
"net/http"
"time"
)
func HandleTunneling(res http.ResponseWriter, req *http.Request) {
dest_conn, err := net.DialTimeout("tcp", req.Host, 10*time.Second)
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
return
}
res.WriteHeader(http.StatusOK)
hijacker, ok := res.(http.Hijacker)
if !ok {
http.Error(res, "Hijacking not supported", http.StatusInternalServerError)
return
}
client_conn, _, err := hijacker.Hijack()
if err != nil {
http.Error(res, err.Error(), http.StatusServiceUnavailable)
}
go transfer(dest_conn, client_conn)
go transfer(client_conn, dest_conn)
}
func transfer(destination io.WriteCloser, source io.ReadCloser) {
defer destination.Close()
defer source.Close()
io.Copy(destination, source)
}

44
proxy/logger.go Normal file
View File

@ -0,0 +1,44 @@
package proxy
import (
"fmt"
"log"
"os"
)
type Logger struct{
}
type FileLogger struct {
Logger
filename string
}
func NewLogger() *Logger {
return &Logger{}
}
func NewFileLogger(filename string) *FileLogger {
if _, err := os.Stat(filename); os.IsNotExist(err) {
os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
}
return &FileLogger{filename: filename}
}
func (l *FileLogger) WriteLine(msg string, print bool, v ...interface{}) {
f, err := os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, 0600)
if err != nil {
panic(err)
}
line := fmt.Sprintf(msg, v...)
if print {
log.Printf(msg, v...)
}
f.WriteString(line)
defer f.Close()
}
func (l *Logger) Printf(msg string, v ...interface{}) {
log.Printf(msg, v...)
}