package dbus

import (
	"errors"
	"io"
	"os"
	"reflect"
	"strings"
	"sync"
)

// Package-level state behind the shared connections handed out by SystemBus
// and SessionBus. systemBus is guarded by systemBusLck and sessionBus by
// sessionBusLck; each stays nil until the first successful call of the
// corresponding function. sessionEnvLck serializes getSessionBusAddress,
// which reads DBUS_SESSION_BUS_ADDRESS and may fall back to a
// platform-specific lookup.
var (
	systemBus     *Conn
	systemBusLck  sync.Mutex
	sessionBus    *Conn
	sessionBusLck sync.Mutex
	sessionEnvLck sync.Mutex
)

// ErrClosed is the error returned by calls on a closed connection. Send
// stores it in Call.Err when the message could not be sent because the
// connection's output side had already been closed.
var ErrClosed = errors.New("dbus: connection closed by user")

// Conn represents a connection to a message bus (usually, the system or
// session bus).
//
// Connections are either shared or private. Shared connections
// are shared between calls to the functions that return them. As a result,
// the methods Close, Auth and Hello must not be called on them.
//
// Multiple goroutines may invoke methods on a connection simultaneously.
type Conn struct {
	// The embedded transport supplies ReadMessage, SendMessage and Close,
	// which Conn calls directly.
	transport

	// busObj is the org.freedesktop.DBus object at /org/freedesktop/DBus
	// (created in newConn). unixFD is what SupportsUnixFDs reports.
	// NOTE(review): neither unixFD nor uuid is assigned in this chunk;
	// presumably both are filled in during authentication — confirm.
	busObj BusObject
	unixFD bool
	uuid   string

	// names tracks the unique and well-known names owned by this connection;
	// inWorker uses it to drop messages addressed to other names.
	names *nameTracker

	// serialGen hands out message serials and takes them back once retired.
	serialGen *serialGenerator

	// calls holds method calls that are still waiting for a reply, keyed by
	// serial.
	calls *callTracker

	// handler receives incoming method calls; Close terminates it if it
	// implements Terminator.
	handler Handler

	// outHandler serializes outgoing messages and records whether the
	// connection has been closed.
	outHandler *outputHandler

	// signalHandler receives every dispatched signal via DeliverSignal.
	signalHandler SignalHandler

	// When eavesdropped is non-nil, inWorker sends every incoming message to
	// it instead of dispatching it. Both reads and writes of the field happen
	// under eavesdroppedLck.
	eavesdropped    chan<- *Message
	eavesdroppedLck sync.Mutex
}

// SessionBus returns a shared connection to the session bus, connecting to it
// if not already done.
//
// The first successful call dials the bus, authenticates and sends Hello; the
// resulting connection is cached and returned by every later call. If any of
// those steps fails the half-initialised connection is closed, the error is
// returned and nothing is cached, so the next call tries again.
func SessionBus() (conn *Conn, err error) {
	sessionBusLck.Lock()
	defer sessionBusLck.Unlock()

	if sessionBus == nil {
		conn, err = SessionBusPrivate()
		if err != nil {
			return nil, err
		}
		if err = conn.Auth(nil); err == nil {
			err = conn.Hello()
		}
		if err != nil {
			// Never cache a connection that did not finish its handshake.
			conn.Close()
			return nil, err
		}
		sessionBus = conn
	}
	return sessionBus, nil
}

// getSessionBusAddress determines where the session bus lives. A non-empty
// DBUS_SESSION_BUS_ADDRESS is used as-is unless it is exactly "autolaunch:";
// in that case, or when the variable is unset or empty, the platform-specific
// lookup decides. Calls are serialized by sessionEnvLck.
func getSessionBusAddress() (string, error) {
	sessionEnvLck.Lock()
	defer sessionEnvLck.Unlock()

	switch addr := os.Getenv("DBUS_SESSION_BUS_ADDRESS"); addr {
	case "", "autolaunch:":
		return getSessionBusPlatformAddress()
	default:
		return addr, nil
	}
}

// SessionBusPrivate returns a new private connection to the session bus.
// The connection is not authenticated yet: the caller is responsible for
// calling Auth and Hello (as SessionBus does).
func SessionBusPrivate() (*Conn, error) {
	address, err := getSessionBusAddress()
	if err == nil {
		return Dial(address)
	}
	return nil, err
}

// SessionBusPrivateHandler returns a new private connection to the session
// bus, using handler for incoming method calls and signalHandler for incoming
// signals. Like SessionBusPrivate, the returned connection still needs Auth
// and Hello before it can be used.
func SessionBusPrivateHandler(handler Handler, signalHandler SignalHandler) (*Conn, error) {
	address, err := getSessionBusAddress()
	if err != nil {
		return nil, err
	}
	return DialHandler(address, handler, signalHandler)
}

// SystemBus returns a shared connection to the system bus, connecting to it if
// not already done.
//
// The first successful call dials, authenticates and sends Hello, then caches
// the connection for all later callers. A failure at any step closes the
// partially set-up connection and leaves the cache empty, so a later call
// retries from scratch.
func SystemBus() (conn *Conn, err error) {
	systemBusLck.Lock()
	defer systemBusLck.Unlock()

	if systemBus != nil {
		return systemBus, nil
	}

	if conn, err = SystemBusPrivate(); err != nil {
		return nil, err
	}
	if err = conn.Auth(nil); err != nil {
		conn.Close()
		return nil, err
	}
	if err = conn.Hello(); err != nil {
		conn.Close()
		return nil, err
	}

	systemBus = conn
	return conn, nil
}

// SystemBusPrivate returns a new private connection to the system bus at the
// platform's default address. The connection is not authenticated yet.
func SystemBusPrivate() (*Conn, error) {
	address := getSystemBusPlatformAddress()
	return Dial(address)
}

// SystemBusPrivateHandler returns a new private connection to the system bus,
// routing incoming method calls to handler and signals to signalHandler. The
// connection is not authenticated yet.
func SystemBusPrivateHandler(handler Handler, signalHandler SignalHandler) (*Conn, error) {
	address := getSystemBusPlatformAddress()
	return DialHandler(address, handler, signalHandler)
}

// Dial establishes a new private connection to the message bus specified by
// address, using the default method-call and signal handlers. address may list
// several candidates separated by ';'; the first one that can be opened wins.
func Dial(address string) (*Conn, error) {
	tr, err := getTransport(address)
	if err == nil {
		return newConn(tr, NewDefaultHandler(), NewDefaultSignalHandler())
	}
	return nil, err
}

// DialHandler establishes a new private connection to the message bus
// specified by address, dispatching incoming method calls to handler and
// signals to signalHandler.
func DialHandler(address string, handler Handler, signalHandler SignalHandler) (*Conn, error) {
	tr, err := getTransport(address)
	if err == nil {
		return newConn(tr, handler, signalHandler)
	}
	return nil, err
}

// NewConn creates a new private *Conn from an already established connection,
// wiring in the default method-call and signal handlers.
func NewConn(conn io.ReadWriteCloser) (*Conn, error) {
	callHandler, sigHandler := NewDefaultHandler(), NewDefaultSignalHandler()
	return NewConnHandler(conn, callHandler, sigHandler)
}

// NewConnHandler creates a new private *Conn from an already established
// connection, wrapping it in a generic transport and using the supplied
// handlers for incoming method calls and signals.
func NewConnHandler(conn io.ReadWriteCloser, handler Handler, signalHandler SignalHandler) (*Conn, error) {
	tr := genericTransport{conn}
	return newConn(tr, handler, signalHandler)
}

// newConn builds a *Conn around tr with fresh call, serial and name trackers.
// It performs no I/O and never fails; the error result is always nil.
func newConn(tr transport, handler Handler, signalHandler SignalHandler) (*Conn, error) {
	conn := &Conn{
		transport:     tr,
		calls:         newCallTracker(),
		handler:       handler,
		signalHandler: signalHandler,
		serialGen:     newSerialGenerator(),
		names:         newNameTracker(),
	}
	// Both of these need a pointer back to the connection, so they can only
	// be filled in once conn exists.
	conn.outHandler = &outputHandler{conn: conn}
	conn.busObj = conn.Object("org.freedesktop.DBus", "/org/freedesktop/DBus")
	return conn, nil
}

// BusObject returns the object owned by the bus daemon which handles
// administrative requests: org.freedesktop.DBus at /org/freedesktop/DBus,
// created once in newConn.
func (conn *Conn) BusObject() BusObject {
	return conn.busObj
}

// Close closes the connection. Any blocked operations will return with errors
// and the channels passed to Eavesdrop and Signal are closed. This method must
// not be called on shared connections.
//
// Close can run more than once for the same connection. Besides explicit
// calls, inWorker calls it whenever reading from the transport fails, and that
// typically happens right after a user-initiated Close has shut the transport
// down. The eavesdrop channel is therefore forgotten once it has been closed,
// so a repeated Close does not panic with "close of closed channel".
func (conn *Conn) Close() error {
	conn.outHandler.close()
	if term, ok := conn.signalHandler.(Terminator); ok {
		term.Terminate()
	}

	if term, ok := conn.handler.(Terminator); ok {
		term.Terminate()
	}

	conn.eavesdroppedLck.Lock()
	if conn.eavesdropped != nil {
		close(conn.eavesdropped)
		// Closing it again would panic, and inWorker must never send on a
		// closed channel; with the field cleared, neither can happen.
		conn.eavesdropped = nil
	}
	conn.eavesdroppedLck.Unlock()

	return conn.transport.Close()
}

// Eavesdrop causes conn to send all incoming messages to the given channel
// without further processing. Method replies, errors and signals will not be
// sent to the appropriate channels and method calls will not be handled. If
// nil is passed, the normal behaviour is restored.
//
// The caller has to make sure that ch is sufficiently buffered; if a message
// arrives when a write to ch is not possible, the message is discarded.
func (conn *Conn) Eavesdrop(ch chan<- *Message) {
	conn.eavesdroppedLck.Lock()
	defer conn.eavesdroppedLck.Unlock()
	conn.eavesdropped = ch
}

// getSerial returns an unused serial from the connection's serial generator.
// The serial stays reserved until it is retired again (sendMessageAndIfClosed
// retires it for messages that expect no reply; inWorker does so when a reply
// or error for it arrives).
func (conn *Conn) getSerial() uint32 {
	return conn.serialGen.getSerial()
}

// Hello sends the initial org.freedesktop.DBus.Hello call. This method must be
// called after authentication, but before sending any other messages to the
// bus. Hello must not be called for shared connections.
//
// On success the unique name returned by the bus is recorded as this
// connection's own name.
func (conn *Conn) Hello() error {
	var uniqueName string
	call := conn.busObj.Call("org.freedesktop.DBus.Hello", 0)
	if err := call.Store(&uniqueName); err != nil {
		return err
	}
	conn.names.acquireUniqueConnectionName(uniqueName)
	return nil
}

// inWorker runs in its own goroutine, reading incoming messages from the
// transport and dispatching them appropriately. Messages that fail with
// InvalidMessageError are skipped. Any other read error closes the connection,
// fails every pending call with that error and ends the loop.
//
// NOTE(review): nothing in this chunk starts this goroutine; presumably it is
// launched after authentication elsewhere — confirm.
func (conn *Conn) inWorker() {
	for {
		msg, err := conn.ReadMessage()
		if err != nil {
			if _, ok := err.(InvalidMessageError); !ok {
				// Some read error occurred (usually EOF). Nothing can be done
				// except to shut everything down and return the error to all
				// pending replies.
				conn.Close()
				conn.calls.finalizeAllWithError(err)
				return
			}
			// invalid messages are ignored
			continue
		}
		// While an eavesdrop channel is set, every message goes there (and is
		// dropped if the channel is full) and normal dispatch is skipped.
		conn.eavesdroppedLck.Lock()
		if conn.eavesdropped != nil {
			select {
			case conn.eavesdropped <- msg:
			default:
			}
			conn.eavesdroppedLck.Unlock()
			continue
		}
		conn.eavesdroppedLck.Unlock()
		// Only dispatch messages that have no destination, arrive before Hello
		// has recorded a unique name, or are addressed to a name we own.
		dest, _ := msg.Headers[FieldDestination].value.(string)
		found := dest == "" ||
			!conn.names.uniqueNameIsKnown() ||
			conn.names.isKnownName(dest)
		if !found {
			// The message is addressed to a name this connection does not own
			// (presumably eavesdropped traffic); ignore it.
			continue
		}
		// Replies and errors complete the pending call and free its serial.
		// Signals are delivered synchronously on this goroutine. Method calls
		// are handled in a goroutine of their own.
		switch msg.Type {
		case TypeError:
			conn.serialGen.retireSerial(conn.calls.handleDBusError(msg))
		case TypeMethodReply:
			conn.serialGen.retireSerial(conn.calls.handleReply(msg))
		case TypeSignal:
			conn.handleSignal(msg)
		case TypeMethodCall:
			go conn.handleCall(msg)
		}

	}
}

// handleSignal processes an incoming signal message. NameLost and NameAcquired
// signals sent by org.freedesktop.DBus update the set of names tracked for this
// connection. Every signal, including those two, is then passed to the
// connection's SignalHandler.
//
// A NameLost/NameAcquired signal whose body does not start with a string is
// malformed. Instead of panicking, which would unwind inWorker and crash the
// whole program, the name bookkeeping is skipped for it, and it is still
// delivered like any other signal.
func (conn *Conn) handleSignal(msg *Message) {
	iface := msg.Headers[FieldInterface].value.(string)
	member := msg.Headers[FieldMember].value.(string)
	// As per http://dbus.freedesktop.org/doc/dbus-specification.html,
	// sender is optional for signals.
	sender, _ := msg.Headers[FieldSender].value.(string)
	if iface == "org.freedesktop.DBus" && sender == "org.freedesktop.DBus" {
		// The affected bus name is the first body argument.
		var name string
		ok := len(msg.Body) > 0
		if ok {
			name, ok = msg.Body[0].(string)
		}
		if ok {
			switch member {
			case "NameLost":
				// We lost the name on the bus: stop tracking it.
				conn.names.loseName(name)
			case "NameAcquired":
				// We acquired the name on the bus: start tracking it.
				conn.names.acquireName(name)
			}
		}
	}
	signal := &Signal{
		Sender: sender,
		Path:   msg.Headers[FieldPath].value.(ObjectPath),
		Name:   iface + "." + member,
		Body:   msg.Body,
	}
	conn.signalHandler.DeliverSignal(iface, member, signal)
}

// Names returns the list of all names that are currently owned by this
// connection. The slice is always at least one element long, the first element
// being the unique name of the connection (the empty string until Hello has
// succeeded). The remaining well-known names are in no particular order.
func (conn *Conn) Names() []string {
	return conn.names.listKnownNames()
}

// Object returns the object identified by the given destination name and path.
// No message is sent here; the returned *Object only records conn, dest and
// path for use by its methods.
func (conn *Conn) Object(dest string, path ObjectPath) BusObject {
	return &Object{conn, dest, path}
}

// sendMessage sends msg over the connection and does nothing extra if the
// connection has already been closed. If writing fails, the error is
// reported to the pending call registered under msg's serial (if any), and
// the serial is retired.
func (conn *Conn) sendMessage(msg *Message) {
	conn.sendMessageAndIfClosed(msg, func() {})
}

// sendMessageAndIfClosed hands msg to the output handler, which runs ifClosed
// instead of sending when the connection is already closed. A send error is
// forwarded to the call tracker. The serial is retired immediately unless msg
// is a method call and no send error was reported. In that case it stays
// reserved until inWorker sees the matching reply or error. Note that the
// closed path also reports no error.
func (conn *Conn) sendMessageAndIfClosed(msg *Message, ifClosed func()) {
	err := conn.outHandler.sendAndIfClosed(msg, ifClosed)
	conn.calls.handleSendError(msg, err)

	awaitingReply := err == nil && msg.Type == TypeMethodCall
	if !awaitingReply {
		conn.serialGen.retireSerial(msg.serial)
	}
}

// Send sends the given message to the message bus. You usually don't need to
// use this; use the higher-level equivalents (Call / Go, Emit and Export)
// instead. If msg is a method call and NoReplyExpected is not set, a non-nil
// call is returned and the same value is sent to ch (which must be buffered)
// once the call is complete. Otherwise, ch is ignored and a Call structure is
// returned of which only the Err member is valid.
//
// If the connection has already been closed, the returned call carries
// ErrClosed. For a method call that expects a reply, that call is also
// delivered on ch.
func (conn *Conn) Send(msg *Message, ch chan *Call) *Call {
	var call *Call

	msg.serial = conn.getSerial()
	if msg.Type == TypeMethodCall && msg.Flags&FlagNoReplyExpected == 0 {
		if ch == nil {
			ch = make(chan *Call, 5)
		} else if cap(ch) == 0 {
			panic("dbus: unbuffered channel passed to (*Conn).Send")
		}
		call = new(Call)
		call.Destination, _ = msg.Headers[FieldDestination].value.(string)
		call.Path, _ = msg.Headers[FieldPath].value.(ObjectPath)
		iface, _ := msg.Headers[FieldInterface].value.(string)
		member, _ := msg.Headers[FieldMember].value.(string)
		call.Method = iface + "." + member
		call.Args = msg.Body
		call.Done = ch
		// Track before sending so that a fast reply always finds its call.
		conn.calls.track(msg.serial, call)
		conn.sendMessageAndIfClosed(msg, func() {
			// The message was never written, so no reply will ever arrive to
			// untrack the call or free its serial. Do both here; otherwise
			// every call made on a closed connection would leak a tracker
			// entry and a reserved serial.
			conn.calls.finalize(msg.serial)
			conn.serialGen.retireSerial(msg.serial)
			call.Err = ErrClosed
			call.Done <- call
		})
	} else {
		call = &Call{Err: nil}
		conn.sendMessageAndIfClosed(msg, func() {
			call = &Call{Err: ErrClosed}
		})
	}
	return call
}

// sendError sends an error reply for the call with the given serial to dest
// (the destination header is omitted when dest is empty). err is converted to
// a D-Bus error. An Error or *Error is used as is, and a DBusError supplies its
// own name and body. Anything else is wrapped by MakeFailedError.
func (conn *Conn) sendError(err error, dest string, serial uint32) {
	var dbusErr *Error
	switch v := err.(type) {
	case *Error:
		dbusErr = v
	case Error:
		dbusErr = &v
	case DBusError:
		errName, errBody := v.DBusError()
		dbusErr = NewError(errName, errBody)
	default:
		dbusErr = MakeFailedError(err)
	}

	sn := conn.getSerial()
	headers := map[HeaderField]Variant{
		FieldErrorName:   MakeVariant(dbusErr.Name),
		FieldReplySerial: MakeVariant(serial),
	}
	if dest != "" {
		headers[FieldDestination] = MakeVariant(dest)
	}
	if len(dbusErr.Body) > 0 {
		headers[FieldSignature] = MakeVariant(SignatureOf(dbusErr.Body...))
	}
	conn.sendMessage(&Message{
		Type:    TypeError,
		serial:  sn,
		Headers: headers,
		Body:    dbusErr.Body,
	})
}

// sendReply sends a method return for the call with the given serial to dest
// (the destination header is omitted when dest is empty), carrying values as
// its body. The signature header is only set when there are values.
func (conn *Conn) sendReply(dest string, serial uint32, values ...interface{}) {
	sn := conn.getSerial()
	headers := map[HeaderField]Variant{
		FieldReplySerial: MakeVariant(serial),
	}
	if dest != "" {
		headers[FieldDestination] = MakeVariant(dest)
	}
	if len(values) > 0 {
		headers[FieldSignature] = MakeVariant(SignatureOf(values...))
	}
	conn.sendMessage(&Message{
		Type:    TypeMethodReply,
		serial:  sn,
		Headers: headers,
		Body:    values,
	})
}

// defaultSignalAction applies fn to the connection's signal handler and ch,
// but only when that handler is the default one; with a custom SignalHandler
// it does nothing.
func (conn *Conn) defaultSignalAction(fn func(h *defaultSignalHandler, ch chan<- *Signal), ch chan<- *Signal) {
	if isDefaultSignalHandler(conn.signalHandler) {
		fn(conn.signalHandler.(*defaultSignalHandler), ch)
	}
}

// Signal registers the given channel to be passed all received signal messages.
// The caller has to make sure that ch is sufficiently buffered; if a message
// arrives when a write to ch is not possible, it is discarded.
//
// Multiple of these channels can be registered at the same time.
//
// These channels are "overwritten" by Eavesdrop; i.e., if there currently is a
// channel for eavesdropped messages, this channel receives all signals, and
// none of the channels passed to Signal will receive any signals.
//
// Registration only takes effect when the connection uses the default signal
// handler; with a custom SignalHandler this call does nothing.
func (conn *Conn) Signal(ch chan<- *Signal) {
	conn.defaultSignalAction((*defaultSignalHandler).addSignal, ch)
}

// RemoveSignal removes the given channel from the list of the registered
// channels. Like Signal, it only has an effect when the connection uses the
// default signal handler.
func (conn *Conn) RemoveSignal(ch chan<- *Signal) {
	conn.defaultSignalAction((*defaultSignalHandler).removeSignal, ch)
}

// SupportsUnixFDs returns whether the underlying transport supports passing of
// unix file descriptors. If this is false, method calls containing unix file
// descriptors will return an error and emitted signals containing them will
// not be sent.
func (conn *Conn) SupportsUnixFDs() bool {
	// NOTE(review): unixFD is not assigned in this chunk; presumably it is set
	// during authentication — confirm.
	return conn.unixFD
}

// Error represents a D-Bus message of type Error.
type Error struct {
	Name string
	Body []interface{}
}

// NewError returns an *Error with the given D-Bus error name and body.
func NewError(name string, body []interface{}) *Error {
	return &Error{Name: name, Body: body}
}

// Error implements the error interface. If the first body element is a
// string, that string is the message; otherwise the error name is used.
func (e Error) Error() string {
	if len(e.Body) == 0 {
		return e.Name
	}
	if msg, isString := e.Body[0].(string); isString {
		return msg
	}
	return e.Name
}

// Signal represents a D-Bus message of type Signal. The name member is given in
// "interface.member" notation, e.g. org.freedesktop.DBus.NameLost.
//
// Sender may be empty, because the sender header is optional for signals.
// Body holds the signal's arguments as decoded from the message.
type Signal struct {
	Sender string
	Path   ObjectPath
	Name   string
	Body   []interface{}
}

// transport is a D-Bus transport. Conn embeds it, so ReadMessage, SendMessage
// and Close are invoked directly on the connection.
type transport interface {
	// Read and Write raw data (for example, for the authentication protocol).
	io.ReadWriteCloser

	// Send the initial null byte used for the EXTERNAL mechanism.
	SendNullByte() error

	// Returns whether this transport supports passing Unix FDs.
	SupportsUnixFDs() bool

	// Signal the transport that Unix FD passing is enabled for this connection.
	EnableUnixFDs()

	// Read / send a message, handling things like Unix FDs.
	ReadMessage() (*Message, error)
	SendMessage(*Message) error
}

// transports maps a transport name (the part of a bus address before the
// first ':') to a constructor that receives the rest of the address.
// NOTE(review): entries are presumably registered by the individual transport
// implementations elsewhere in the package — confirm.
var (
	transports = make(map[string]func(string) (transport, error))
)

// getTransport opens the first usable transport from a ';'-separated list of
// bus addresses of the form "name:params". Candidates are tried in order. If
// none succeeds, the error from the last candidate is returned.
func getTransport(address string) (transport, error) {
	var lastErr error
	for _, entry := range strings.Split(address, ";") {
		colon := strings.IndexByte(entry, ':')
		if colon < 0 {
			lastErr = errors.New("dbus: invalid bus address (no transport)")
			continue
		}
		open := transports[entry[:colon]]
		if open == nil {
			lastErr = errors.New("dbus: invalid bus address (invalid or unsupported transport)")
			continue
		}
		t, err := open(entry[colon+1:])
		if err == nil {
			return t, nil
		}
		lastErr = err
	}
	return nil, lastErr
}

// dereferenceAll replaces, in place, every element of vs with the value it
// points to, and returns vs. Each element is assumed to be a pointer, so
// reflect.Value.Elem panics otherwise.
func dereferenceAll(vs []interface{}) []interface{} {
	for i, ptr := range vs {
		vs[i] = reflect.ValueOf(ptr).Elem().Interface()
	}
	return vs
}

// getKey looks up key in a comma-separated list of key=value pairs such as the
// parameter part of a D-Bus address. Each pair is split at its first '='. The
// value of the first matching pair is returned, or "" when the key is absent
// or has no '='.
func getKey(s, key string) string {
	for _, pair := range strings.Split(s, ",") {
		eq := strings.Index(pair, "=")
		if eq >= 0 && pair[:eq] == key {
			return pair[eq+1:]
		}
	}
	return ""
}

// outputHandler serializes outgoing messages for a connection and records
// whether the connection has been closed. sendLck ensures only one message is
// written to the transport at a time. closed.isClosed is set by close().
// closed.lck is held for reading for the whole duration of each send, so
// close() waits for in-flight sends to finish.
type outputHandler struct {
	conn    *Conn
	sendLck sync.Mutex
	closed  struct {
		isClosed bool
		lck      sync.RWMutex
	}
}

// sendAndIfClosed writes msg to the connection's transport and returns the
// write error. If the handler has already been closed, it calls ifClosed
// instead and returns nil. The read lock on closed is held throughout, so
// close() cannot complete while a send is in progress. sendLck serializes the
// actual writes.
func (h *outputHandler) sendAndIfClosed(msg *Message, ifClosed func()) error {
	h.closed.lck.RLock()
	defer h.closed.lck.RUnlock()
	if h.closed.isClosed {
		ifClosed()
		return nil
	}
	h.sendLck.Lock()
	defer h.sendLck.Unlock()
	return h.conn.SendMessage(msg)
}

// close marks the handler as closed. It blocks until every in-progress
// sendAndIfClosed has released its read lock. Afterwards, sends only run their
// ifClosed callback.
func (h *outputHandler) close() {
	h.closed.lck.Lock()
	h.closed.isClosed = true
	h.closed.lck.Unlock()
}

// serialGenerator hands out message serials that are not currently in use.
// A serial stays reserved from getSerial until retireSerial is called for it.
// All methods are safe for concurrent use.
type serialGenerator struct {
	lck        sync.Mutex
	nextSerial uint32
	serialUsed map[uint32]bool
}

// newSerialGenerator returns a generator whose first serial is 1. Serial 0 is
// marked as used up front because D-Bus does not allow a zero serial.
func newSerialGenerator() *serialGenerator {
	return &serialGenerator{
		serialUsed: map[uint32]bool{0: true},
		nextSerial: 1,
	}
}

// getSerial reserves and returns the first unused, non-zero serial at or after
// nextSerial, wrapping around after the maximum uint32.
func (gen *serialGenerator) getSerial() uint32 {
	gen.lck.Lock()
	defer gen.lck.Unlock()
	n := gen.nextSerial
	// Skip 0 explicitly instead of relying only on its map entry. Without the
	// guard in retireSerial, a reply carrying reply serial 0 would delete that
	// entry, and once nextSerial wraps around, 0 would be handed out.
	for n == 0 || gen.serialUsed[n] {
		n++
	}
	gen.serialUsed[n] = true
	gen.nextSerial = n + 1
	return n
}

// retireSerial makes serial available again. Serial 0 is never valid and is
// ignored.
func (gen *serialGenerator) retireSerial(serial uint32) {
	if serial == 0 {
		return
	}
	gen.lck.Lock()
	defer gen.lck.Unlock()
	delete(gen.serialUsed, serial)
}

// nameTracker records the bus names owned by a connection: its unique name
// (empty until set) and the set of well-known names acquired so far. All
// methods are safe for concurrent use.
type nameTracker struct {
	lck    sync.RWMutex
	unique string
	names  map[string]struct{}
}

// newNameTracker returns a tracker with no unique name and no other names.
func newNameTracker() *nameTracker {
	return &nameTracker{names: make(map[string]struct{})}
}

// acquireUniqueConnectionName records name as the connection's unique name.
func (nt *nameTracker) acquireUniqueConnectionName(name string) {
	nt.lck.Lock()
	defer nt.lck.Unlock()
	nt.unique = name
}

// acquireName adds name to the set of owned well-known names.
func (nt *nameTracker) acquireName(name string) {
	nt.lck.Lock()
	defer nt.lck.Unlock()
	nt.names[name] = struct{}{}
}

// loseName removes name from the set of owned well-known names.
func (nt *nameTracker) loseName(name string) {
	nt.lck.Lock()
	defer nt.lck.Unlock()
	delete(nt.names, name)
}

// uniqueNameIsKnown reports whether a unique name has been recorded.
func (nt *nameTracker) uniqueNameIsKnown() bool {
	nt.lck.RLock()
	defer nt.lck.RUnlock()
	return len(nt.unique) > 0
}

// isKnownName reports whether name is the unique name or one of the owned
// well-known names.
func (nt *nameTracker) isKnownName(name string) bool {
	nt.lck.RLock()
	defer nt.lck.RUnlock()
	if name == nt.unique {
		return true
	}
	_, owned := nt.names[name]
	return owned
}

// listKnownNames returns the unique name followed by the well-known names in
// no particular order.
func (nt *nameTracker) listKnownNames() []string {
	nt.lck.RLock()
	defer nt.lck.RUnlock()
	known := make([]string, 1, len(nt.names)+1)
	known[0] = nt.unique
	for name := range nt.names {
		known = append(known, name)
	}
	return known
}

// callTracker keeps a connection's pending method calls, keyed by the serial of
// the call message, until a reply, an error reply, a send failure or a
// connection shutdown completes them.
//
// Every completion path removes the call from the map under the write lock
// before touching it (see claim). As a result, each call is completed exactly
// once, and its Err/Body fields are written by a single goroutine even when
// several completion paths race.
type callTracker struct {
	calls map[uint32]*Call
	lck   sync.RWMutex
}

// newCallTracker returns an empty tracker.
func newCallTracker() *callTracker {
	return &callTracker{calls: map[uint32]*Call{}}
}

// track registers call as pending under serial sn.
func (tracker *callTracker) track(sn uint32, call *Call) {
	tracker.lck.Lock()
	defer tracker.lck.Unlock()
	tracker.calls[sn] = call
}

// claim atomically removes and returns the pending call for sn, if any. Only
// the caller that gets ok == true may complete the call.
func (tracker *callTracker) claim(sn uint32) (*Call, bool) {
	tracker.lck.Lock()
	defer tracker.lck.Unlock()
	c, ok := tracker.calls[sn]
	if ok {
		delete(tracker.calls, sn)
	}
	return c, ok
}

// handleReply completes the call answered by msg (a method return) with msg's
// body. It returns the reply serial so the caller can retire it, whether or
// not a matching call was pending.
func (tracker *callTracker) handleReply(msg *Message) uint32 {
	serial := msg.Headers[FieldReplySerial].value.(uint32)
	if c, ok := tracker.claim(serial); ok {
		c.Body = msg.Body
		c.Done <- c
	}
	return serial
}

// handleDBusError completes the call answered by msg (an error reply) with an
// Error built from the error name header and msg's body. It returns the reply
// serial, whether or not a matching call was pending.
func (tracker *callTracker) handleDBusError(msg *Message) uint32 {
	serial := msg.Headers[FieldReplySerial].value.(uint32)
	if c, ok := tracker.claim(serial); ok {
		name, _ := msg.Headers[FieldErrorName].value.(string)
		c.Err = Error{Name: name, Body: msg.Body}
		c.Done <- c
	}
	return serial
}

// handleSendError completes the call registered for msg with err. A nil err is
// a no-op.
func (tracker *callTracker) handleSendError(msg *Message, err error) {
	if err == nil {
		return
	}
	if c, ok := tracker.claim(msg.serial); ok {
		c.Err = err
		c.Done <- c
	}
}

// finalize forgets the call registered under sn without completing it.
func (tracker *callTracker) finalize(sn uint32) {
	tracker.lck.Lock()
	defer tracker.lck.Unlock()
	delete(tracker.calls, sn)
}

// finalizeAllWithError completes every pending call with err. The pending set
// is swapped out under the lock and completed afterwards. This way the lock is
// not held while sending on callers' Done channels, where a full channel would
// otherwise stall every other tracker operation.
func (tracker *callTracker) finalizeAllWithError(err error) {
	tracker.lck.Lock()
	pending := tracker.calls
	tracker.calls = make(map[uint32]*Call)
	tracker.lck.Unlock()

	for _, c := range pending {
		c.Err = err
		c.Done <- c
	}
}
