Unverified Commit 47380709 authored by Stephan Renatus's avatar Stephan Renatus Committed by GitHub

Merge pull request #1338 from srenatus/sr/update-go-ldap

update go-ldap, improve errors
parents 2425c6ea c14b2fd5
...@@ -293,6 +293,9 @@ func (c *ldapConnector) do(ctx context.Context, f func(c *ldap.Conn) error) erro ...@@ -293,6 +293,9 @@ func (c *ldapConnector) do(ctx context.Context, f func(c *ldap.Conn) error) erro
// If bindDN and bindPW are empty this will default to an anonymous bind. // If bindDN and bindPW are empty this will default to an anonymous bind.
if err := conn.Bind(c.BindDN, c.BindPW); err != nil { if err := conn.Bind(c.BindDN, c.BindPW); err != nil {
if c.BindDN == "" && c.BindPW == "" {
return fmt.Errorf("ldap: initial anonymous bind failed: %v", err)
}
return fmt.Errorf("ldap: initial bind for user %q failed: %v", c.BindDN, err) return fmt.Errorf("ldap: initial bind for user %q failed: %v", c.BindDN, err)
} }
...@@ -472,7 +475,7 @@ func (c *ldapConnector) Login(ctx context.Context, s connector.Scopes, username, ...@@ -472,7 +475,7 @@ func (c *ldapConnector) Login(ctx context.Context, s connector.Scopes, username,
func (c *ldapConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { func (c *ldapConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
var data refreshData var data refreshData
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil { if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
return ident, fmt.Errorf("ldap: failed to unamrshal internal data: %v", err) return ident, fmt.Errorf("ldap: failed to unmarshal internal data: %v", err)
} }
var user ldap.Entry var user ldap.Entry
......
hash: fe29de07f5c1580c51de0e78796bce522d933602d88a4c397b586bd88ca7ca76 hash: e5972bbdf15ad612d99ce8cd34e19537b9eacb5ff53688f339e0da285eb8ec22
updated: 2018-10-24T14:58:32.448481302-07:00 updated: 2018-11-12T19:38:56.235070564+01:00
imports: imports:
- name: github.com/beevik/etree - name: github.com/beevik/etree
version: 4cd0dd976db869f817248477718071a28e978df0 version: 4cd0dd976db869f817248477718071a28e978df0
...@@ -180,7 +180,7 @@ imports: ...@@ -180,7 +180,7 @@ imports:
- name: gopkg.in/asn1-ber.v1 - name: gopkg.in/asn1-ber.v1
version: 4e86f4367175e39f69d9358a5f17b4dda270378d version: 4e86f4367175e39f69d9358a5f17b4dda270378d
- name: gopkg.in/ldap.v2 - name: gopkg.in/ldap.v2
version: 0e7db8eb77695b5a952f0e5d78df9ab160050c73 version: bb7a9ca6e4fbc2129e3db588a34bc970ffe811a9
- name: gopkg.in/square/go-jose.v2 - name: gopkg.in/square/go-jose.v2
version: 8254d6c783765f38c8675fae4427a1fe73fbd09d version: 8254d6c783765f38c8675fae4427a1fe73fbd09d
subpackages: subpackages:
......
...@@ -19,7 +19,7 @@ import: ...@@ -19,7 +19,7 @@ import:
# LDAP dependencies. # LDAP dependencies.
- package: gopkg.in/ldap.v2 - package: gopkg.in/ldap.v2
version: 0e7db8eb77695b5a952f0e5d78df9ab160050c73 version: v2.5.1
- package: gopkg.in/asn1-ber.v1 - package: gopkg.in/asn1-ber.v1
version: 4e86f4367175e39f69d9358a5f17b4dda270378d version: 4e86f4367175e39f69d9358a5f17b4dda270378d
......
Copyright (c) 2012 The Go Authors. All rights reserved. The MIT License (MIT)
Redistribution and use in source and binary forms, with or without Copyright (c) 2011-2015 Michael Mitton (mmitton@gmail.com)
modification, are permitted provided that the following conditions are Portions copyright (c) 2015-2016 go-ldap Authors
met:
* Redistributions of source code must retain the above copyright Permission is hereby granted, free of charge, to any person obtaining a copy
notice, this list of conditions and the following disclaimer. of this software and associated documentation files (the "Software"), to deal
* Redistributions in binary form must reproduce the above in the Software without restriction, including without limitation the rights
copyright notice, this list of conditions and the following disclaimer to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
in the documentation and/or other materials provided with the copies of the Software, and to permit persons to whom the Software is
distribution. furnished to do so, subject to the following conditions:
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS The above copyright notice and this permission notice shall be included in all
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT copies or substantial portions of the Software.
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE SOFTWARE.
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
...@@ -16,73 +16,78 @@ import ( ...@@ -16,73 +16,78 @@ import (
"gopkg.in/asn1-ber.v1" "gopkg.in/asn1-ber.v1"
) )
// Attribute represents an LDAP attribute
type Attribute struct { type Attribute struct {
attrType string // Type is the name of the LDAP attribute
attrVals []string Type string
// Vals are the LDAP attribute values
Vals []string
} }
func (a *Attribute) encode() *ber.Packet { func (a *Attribute) encode() *ber.Packet {
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.attrType, "Type")) seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.Type, "Type"))
set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue") set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
for _, value := range a.attrVals { for _, value := range a.Vals {
set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals")) set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals"))
} }
seq.AppendChild(set) seq.AppendChild(set)
return seq return seq
} }
// AddRequest represents an LDAP AddRequest operation
type AddRequest struct { type AddRequest struct {
dn string // DN identifies the entry being added
attributes []Attribute DN string
// Attributes list the attributes of the new entry
Attributes []Attribute
} }
func (a AddRequest) encode() *ber.Packet { func (a AddRequest) encode() *ber.Packet {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.dn, "DN")) request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.DN, "DN"))
attributes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") attributes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
for _, attribute := range a.attributes { for _, attribute := range a.Attributes {
attributes.AppendChild(attribute.encode()) attributes.AppendChild(attribute.encode())
} }
request.AppendChild(attributes) request.AppendChild(attributes)
return request return request
} }
// Attribute adds an attribute with the given type and values
func (a *AddRequest) Attribute(attrType string, attrVals []string) { func (a *AddRequest) Attribute(attrType string, attrVals []string) {
a.attributes = append(a.attributes, Attribute{attrType: attrType, attrVals: attrVals}) a.Attributes = append(a.Attributes, Attribute{Type: attrType, Vals: attrVals})
} }
// NewAddRequest returns an AddRequest for the given DN, with no attributes
func NewAddRequest(dn string) *AddRequest { func NewAddRequest(dn string) *AddRequest {
return &AddRequest{ return &AddRequest{
dn: dn, DN: dn,
} }
} }
// Add performs the given AddRequest
func (l *Conn) Add(addRequest *AddRequest) error { func (l *Conn) Add(addRequest *AddRequest) error {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
packet.AppendChild(addRequest.encode()) packet.AppendChild(addRequest.encode())
l.Debug.PrintPacket(packet) l.Debug.PrintPacket(packet)
channel, err := l.sendMessage(packet) msgCtx, err := l.sendMessage(packet)
if err != nil { if err != nil {
return err return err
} }
if channel == nil { defer l.finishMessage(msgCtx)
return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
l.Debug.Printf("%d: waiting for response", messageID) l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-channel packetResponse, ok := <-msgCtx.responses
if !ok { if !ok {
return NewError(ErrorNetwork, errors.New("ldap: channel closed")) return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
} }
packet, err = packetResponse.ReadPacket() packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", messageID, packet) l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil { if err != nil {
return err return err
} }
...@@ -103,6 +108,6 @@ func (l *Conn) Add(addRequest *AddRequest) error { ...@@ -103,6 +108,6 @@ func (l *Conn) Add(addRequest *AddRequest) error {
log.Printf("Unexpected Response: %d", packet.Children[1].Tag) log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
} }
l.Debug.Printf("%d: returning", messageID) l.Debug.Printf("%d: returning", msgCtx.id)
return nil return nil
} }
// +build go1.4
package ldap
import (
"sync/atomic"
)
// For compilers that support it, we just use the underlying sync/atomic.Value
// type.
type atomicValue struct {
atomic.Value
}
// +build !go1.4
package ldap
import (
"sync"
)
// This is a helper type that emulates the use of the "sync/atomic.Value"
// struct that's available in Go 1.4 and up.
type atomicValue struct {
value interface{}
lock sync.RWMutex
}
func (av *atomicValue) Store(val interface{}) {
av.lock.Lock()
av.value = val
av.lock.Unlock()
}
func (av *atomicValue) Load() interface{} {
av.lock.RLock()
ret := av.value
av.lock.RUnlock()
return ret
}
...@@ -10,16 +10,22 @@ import ( ...@@ -10,16 +10,22 @@ import (
"gopkg.in/asn1-ber.v1" "gopkg.in/asn1-ber.v1"
) )
// SimpleBindRequest represents a username/password bind operation
type SimpleBindRequest struct { type SimpleBindRequest struct {
// Username is the name of the Directory object that the client wishes to bind as
Username string Username string
// Password is the credentials to bind with
Password string Password string
// Controls are optional controls to send with the bind request
Controls []Control Controls []Control
} }
// SimpleBindResult contains the response from the server
type SimpleBindResult struct { type SimpleBindResult struct {
Controls []Control Controls []Control
} }
// NewSimpleBindRequest returns a bind request
func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest { func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest {
return &SimpleBindRequest{ return &SimpleBindRequest{
Username: username, Username: username,
...@@ -39,11 +45,10 @@ func (bindRequest *SimpleBindRequest) encode() *ber.Packet { ...@@ -39,11 +45,10 @@ func (bindRequest *SimpleBindRequest) encode() *ber.Packet {
return request return request
} }
// SimpleBind performs the simple bind operation defined in the given request
func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) { func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
encodedBindRequest := simpleBindRequest.encode() encodedBindRequest := simpleBindRequest.encode()
packet.AppendChild(encodedBindRequest) packet.AppendChild(encodedBindRequest)
...@@ -51,21 +56,18 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu ...@@ -51,21 +56,18 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu
ber.PrintPacket(packet) ber.PrintPacket(packet)
} }
channel, err := l.sendMessage(packet) msgCtx, err := l.sendMessage(packet)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if channel == nil { defer l.finishMessage(msgCtx)
return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
packetResponse, ok := <-channel packetResponse, ok := <-msgCtx.responses
if !ok { if !ok {
return nil, NewError(ErrorNetwork, errors.New("ldap: channel closed")) return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
} }
packet, err = packetResponse.ReadPacket() packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", messageID, packet) l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -95,11 +97,10 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu ...@@ -95,11 +97,10 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu
return result, nil return result, nil
} }
// Bind performs a bind with the given username and password
func (l *Conn) Bind(username, password string) error { func (l *Conn) Bind(username, password string) error {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, username, "User Name")) bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, username, "User Name"))
...@@ -110,21 +111,18 @@ func (l *Conn) Bind(username, password string) error { ...@@ -110,21 +111,18 @@ func (l *Conn) Bind(username, password string) error {
ber.PrintPacket(packet) ber.PrintPacket(packet)
} }
channel, err := l.sendMessage(packet) msgCtx, err := l.sendMessage(packet)
if err != nil { if err != nil {
return err return err
} }
if channel == nil { defer l.finishMessage(msgCtx)
return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
packetResponse, ok := <-channel packetResponse, ok := <-msgCtx.responses
if !ok { if !ok {
return NewError(ErrorNetwork, errors.New("ldap: channel closed")) return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
} }
packet, err = packetResponse.ReadPacket() packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", messageID, packet) l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -33,9 +33,8 @@ import ( ...@@ -33,9 +33,8 @@ import (
// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise // Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise
// false with any error that occurs if any. // false with any error that occurs if any.
func (l *Conn) Compare(dn, attribute, value string) (bool, error) { func (l *Conn) Compare(dn, attribute, value string) (bool, error) {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request") request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request")
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, dn, "DN")) request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, dn, "DN"))
...@@ -48,22 +47,19 @@ func (l *Conn) Compare(dn, attribute, value string) (bool, error) { ...@@ -48,22 +47,19 @@ func (l *Conn) Compare(dn, attribute, value string) (bool, error) {
l.Debug.PrintPacket(packet) l.Debug.PrintPacket(packet)
channel, err := l.sendMessage(packet) msgCtx, err := l.sendMessage(packet)
if err != nil { if err != nil {
return false, err return false, err
} }
if channel == nil { defer l.finishMessage(msgCtx)
return false, NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
l.Debug.Printf("%d: waiting for response", messageID) l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-channel packetResponse, ok := <-msgCtx.responses
if !ok { if !ok {
return false, NewError(ErrorNetwork, errors.New("ldap: channel closed")) return false, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
} }
packet, err = packetResponse.ReadPacket() packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", messageID, packet) l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil { if err != nil {
return false, err return false, err
} }
......
This diff is collapsed.
This diff is collapsed.
...@@ -6,7 +6,7 @@ import ( ...@@ -6,7 +6,7 @@ import (
"gopkg.in/asn1-ber.v1" "gopkg.in/asn1-ber.v1"
) )
// debbuging type // debugging type
// - has a Printf method to write the debug output // - has a Printf method to write the debug output
type debugging bool type debugging bool
......
...@@ -12,8 +12,11 @@ import ( ...@@ -12,8 +12,11 @@ import (
"gopkg.in/asn1-ber.v1" "gopkg.in/asn1-ber.v1"
) )
// DelRequest implements an LDAP deletion request
type DelRequest struct { type DelRequest struct {
// DN is the name of the directory entry to delete
DN string DN string
// Controls hold optional controls to send with the request
Controls []Control Controls []Control
} }
...@@ -23,6 +26,7 @@ func (d DelRequest) encode() *ber.Packet { ...@@ -23,6 +26,7 @@ func (d DelRequest) encode() *ber.Packet {
return request return request
} }
// NewDelRequest creates a delete request for the given DN and controls
func NewDelRequest(DN string, func NewDelRequest(DN string,
Controls []Control) *DelRequest { Controls []Control) *DelRequest {
return &DelRequest{ return &DelRequest{
...@@ -31,10 +35,10 @@ func NewDelRequest(DN string, ...@@ -31,10 +35,10 @@ func NewDelRequest(DN string,
} }
} }
// Del executes the given delete request
func (l *Conn) Del(delRequest *DelRequest) error { func (l *Conn) Del(delRequest *DelRequest) error {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
packet.AppendChild(delRequest.encode()) packet.AppendChild(delRequest.encode())
if delRequest.Controls != nil { if delRequest.Controls != nil {
packet.AppendChild(encodeControls(delRequest.Controls)) packet.AppendChild(encodeControls(delRequest.Controls))
...@@ -42,22 +46,19 @@ func (l *Conn) Del(delRequest *DelRequest) error { ...@@ -42,22 +46,19 @@ func (l *Conn) Del(delRequest *DelRequest) error {
l.Debug.PrintPacket(packet) l.Debug.PrintPacket(packet)
channel, err := l.sendMessage(packet) msgCtx, err := l.sendMessage(packet)
if err != nil { if err != nil {
return err return err
} }
if channel == nil { defer l.finishMessage(msgCtx)
return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
l.Debug.Printf("%d: waiting for response", messageID) l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-channel packetResponse, ok := <-msgCtx.responses
if !ok { if !ok {
return NewError(ErrorNetwork, errors.New("ldap: channel closed")) return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
} }
packet, err = packetResponse.ReadPacket() packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", messageID, packet) l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil { if err != nil {
return err return err
} }
...@@ -78,6 +79,6 @@ func (l *Conn) Del(delRequest *DelRequest) error { ...@@ -78,6 +79,6 @@ func (l *Conn) Del(delRequest *DelRequest) error {
log.Printf("Unexpected Response: %d", packet.Children[1].Tag) log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
} }
l.Debug.Printf("%d: returning", messageID) l.Debug.Printf("%d: returning", msgCtx.id)
return nil return nil
} }
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// //
// File contains DN parsing functionallity // File contains DN parsing functionality
// //
// https://tools.ietf.org/html/rfc4514 // https://tools.ietf.org/html/rfc4514
// //
...@@ -52,22 +52,28 @@ import ( ...@@ -52,22 +52,28 @@ import (
"fmt" "fmt"
"strings" "strings"
ber "gopkg.in/asn1-ber.v1" "gopkg.in/asn1-ber.v1"
) )
// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
type AttributeTypeAndValue struct { type AttributeTypeAndValue struct {
// Type is the attribute type
Type string Type string
// Value is the attribute value
Value string Value string
} }
// RelativeDN represents a relativeDistinguishedName from https://tools.ietf.org/html/rfc4514
type RelativeDN struct { type RelativeDN struct {
Attributes []*AttributeTypeAndValue Attributes []*AttributeTypeAndValue
} }
// DN represents a distinguishedName from https://tools.ietf.org/html/rfc4514
type DN struct { type DN struct {
RDNs []*RelativeDN RDNs []*RelativeDN
} }
// ParseDN returns a distinguishedName or an error
func ParseDN(str string) (*DN, error) { func ParseDN(str string) (*DN, error) {
dn := new(DN) dn := new(DN)
dn.RDNs = make([]*RelativeDN, 0) dn.RDNs = make([]*RelativeDN, 0)
...@@ -77,9 +83,19 @@ func ParseDN(str string) (*DN, error) { ...@@ -77,9 +83,19 @@ func ParseDN(str string) (*DN, error) {
attribute := new(AttributeTypeAndValue) attribute := new(AttributeTypeAndValue)
escaping := false escaping := false
unescapedTrailingSpaces := 0
stringFromBuffer := func() string {
s := buffer.String()
s = s[0 : len(s)-unescapedTrailingSpaces]
buffer.Reset()
unescapedTrailingSpaces = 0
return s
}
for i := 0; i < len(str); i++ { for i := 0; i < len(str); i++ {
char := str[i] char := str[i]
if escaping { if escaping {
unescapedTrailingSpaces = 0
escaping = false escaping = false
switch char { switch char {
case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\': case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
...@@ -94,19 +110,17 @@ func ParseDN(str string) (*DN, error) { ...@@ -94,19 +110,17 @@ func ParseDN(str string) (*DN, error) {
dst := []byte{0} dst := []byte{0}
n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2])) n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2]))
if err != nil { if err != nil {
return nil, errors.New( return nil, fmt.Errorf("Failed to decode escaped character: %s", err)
fmt.Sprintf("Failed to decode escaped character: %s", err))
} else if n != 1 { } else if n != 1 {
return nil, errors.New( return nil, fmt.Errorf("Expected 1 byte when un-escaping, got %d", n)
fmt.Sprintf("Expected 1 byte when un-escaping, got %d", n))
} }
buffer.WriteByte(dst[0]) buffer.WriteByte(dst[0])
i++ i++
} else if char == '\\' { } else if char == '\\' {
unescapedTrailingSpaces = 0
escaping = true escaping = true
} else if char == '=' { } else if char == '=' {
attribute.Type = buffer.String() attribute.Type = stringFromBuffer()
buffer.Reset()
// Special case: If the first character in the value is # the // Special case: If the first character in the value is # the
// following data is BER encoded so we can just fast forward // following data is BER encoded so we can just fast forward
// and decode. // and decode.
...@@ -119,18 +133,20 @@ func ParseDN(str string) (*DN, error) { ...@@ -119,18 +133,20 @@ func ParseDN(str string) (*DN, error) {
} else { } else {
data = str[i:] data = str[i:]
} }
raw_ber, err := enchex.DecodeString(data) rawBER, err := enchex.DecodeString(data)
if err != nil { if err != nil {
return nil, errors.New( return nil, fmt.Errorf("Failed to decode BER encoding: %s", err)
fmt.Sprintf("Failed to decode BER encoding: %s", err))
} }
packet := ber.DecodePacket(raw_ber) packet := ber.DecodePacket(rawBER)
buffer.WriteString(packet.Data.String()) buffer.WriteString(packet.Data.String())
i += len(data) - 1 i += len(data) - 1
} }
} else if char == ',' || char == '+' { } else if char == ',' || char == '+' {
// We're done with this RDN or value, push it // We're done with this RDN or value, push it
attribute.Value = buffer.String() if len(attribute.Type) == 0 {
return nil, errors.New("incomplete type, value pair")
}
attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute) rdn.Attributes = append(rdn.Attributes, attribute)
attribute = new(AttributeTypeAndValue) attribute = new(AttributeTypeAndValue)
if char == ',' { if char == ',' {
...@@ -138,8 +154,17 @@ func ParseDN(str string) (*DN, error) { ...@@ -138,8 +154,17 @@ func ParseDN(str string) (*DN, error) {
rdn = new(RelativeDN) rdn = new(RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0) rdn.Attributes = make([]*AttributeTypeAndValue, 0)
} }
buffer.Reset() } else if char == ' ' && buffer.Len() == 0 {
// ignore unescaped leading spaces
continue
} else { } else {
if char == ' ' {
// Track unescaped spaces in case they are trailing and we need to remove them
unescapedTrailingSpaces++
} else {
// Reset if we see a non-space char
unescapedTrailingSpaces = 0
}
buffer.WriteByte(char) buffer.WriteByte(char)
} }
} }
...@@ -147,9 +172,76 @@ func ParseDN(str string) (*DN, error) { ...@@ -147,9 +172,76 @@ func ParseDN(str string) (*DN, error) {
if len(attribute.Type) == 0 { if len(attribute.Type) == 0 {
return nil, errors.New("DN ended with incomplete type, value pair") return nil, errors.New("DN ended with incomplete type, value pair")
} }
attribute.Value = buffer.String() attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute) rdn.Attributes = append(rdn.Attributes, attribute)
dn.RDNs = append(dn.RDNs, rdn) dn.RDNs = append(dn.RDNs, rdn)
} }
return dn, nil return dn, nil
} }
// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Returns true if they have the same number of relative distinguished names
// and corresponding relative distinguished names (by position) are the same.
func (d *DN) Equal(other *DN) bool {
if len(d.RDNs) != len(other.RDNs) {
return false
}
for i := range d.RDNs {
if !d.RDNs[i].Equal(other.RDNs[i]) {
return false
}
}
return true
}
// AncestorOf returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN.
// "ou=widgets,o=acme.com" is an ancestor of "ou=sprockets,ou=widgets,o=acme.com"
// "ou=widgets,o=acme.com" is not an ancestor of "ou=sprockets,ou=widgets,o=foo.com"
// "ou=widgets,o=acme.com" is not an ancestor of "ou=widgets,o=acme.com"
func (d *DN) AncestorOf(other *DN) bool {
if len(d.RDNs) >= len(other.RDNs) {
return false
}
// Take the last `len(d.RDNs)` RDNs from the other DN to compare against
otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
for i := range d.RDNs {
if !d.RDNs[i].Equal(otherRDNs[i]) {
return false
}
}
return true
}
// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Relative distinguished names are the same if and only if they have the same number of AttributeTypeAndValues
// and each attribute of the first RDN is the same as the attribute of the second RDN with the same attribute type.
// The order of attributes is not significant.
// Case of attribute types is not significant.
func (r *RelativeDN) Equal(other *RelativeDN) bool {
if len(r.Attributes) != len(other.Attributes) {
return false
}
return r.hasAllAttributes(other.Attributes) && other.hasAllAttributes(r.Attributes)
}
func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool {
for _, attr := range attrs {
found := false
for _, myattr := range r.Attributes {
if myattr.Equal(attr) {
found = true
break
}
}
if !found {
return false
}
}
return true
}
// Equal returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue
// Case of the attribute type is not significant
func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool {
return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value
}
...@@ -56,6 +56,7 @@ const ( ...@@ -56,6 +56,7 @@ const (
ErrorUnexpectedResponse = 205 ErrorUnexpectedResponse = 205
) )
// LDAPResultCodeMap contains string descriptions for LDAP error codes
var LDAPResultCodeMap = map[uint8]string{ var LDAPResultCodeMap = map[uint8]string{
LDAPResultSuccess: "Success", LDAPResultSuccess: "Success",
LDAPResultOperationsError: "Operations Error", LDAPResultOperationsError: "Operations Error",
...@@ -96,6 +97,13 @@ var LDAPResultCodeMap = map[uint8]string{ ...@@ -96,6 +97,13 @@ var LDAPResultCodeMap = map[uint8]string{
LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited", LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited",
LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs", LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs",
LDAPResultOther: "Other", LDAPResultOther: "Other",
ErrorNetwork: "Network Error",
ErrorFilterCompile: "Filter Compile Error",
ErrorFilterDecompile: "Filter Decompile Error",
ErrorDebugging: "Debugging Error",
ErrorUnexpectedMessage: "Unexpected Message",
ErrorUnexpectedResponse: "Unexpected Response",
} }
func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) { func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) {
...@@ -115,8 +123,11 @@ func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) { ...@@ -115,8 +123,11 @@ func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) {
return ErrorNetwork, "Invalid packet format" return ErrorNetwork, "Invalid packet format"
} }
// Error holds LDAP error information
type Error struct { type Error struct {
// Err is the underlying error
Err error Err error
// ResultCode is the LDAP error code
ResultCode uint8 ResultCode uint8
} }
...@@ -124,10 +135,12 @@ func (e *Error) Error() string { ...@@ -124,10 +135,12 @@ func (e *Error) Error() string {
return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error()) return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error())
} }
// NewError creates an LDAP error with the given code and underlying error
func NewError(resultCode uint8, err error) error { func NewError(resultCode uint8, err error) error {
return &Error{ResultCode: resultCode, Err: err} return &Error{ResultCode: resultCode, Err: err}
} }
// IsErrorWithCode returns true if the given error is an LDAP error with the given result code
func IsErrorWithCode(err error, desiredResultCode uint8) bool { func IsErrorWithCode(err error, desiredResultCode uint8) bool {
if err == nil { if err == nil {
return false return false
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"gopkg.in/asn1-ber.v1" "gopkg.in/asn1-ber.v1"
) )
// Filter choices
const ( const (
FilterAnd = 0 FilterAnd = 0
FilterOr = 1 FilterOr = 1
...@@ -28,6 +29,7 @@ const ( ...@@ -28,6 +29,7 @@ const (
FilterExtensibleMatch = 9 FilterExtensibleMatch = 9
) )
// FilterMap contains human readable descriptions of Filter choices
var FilterMap = map[uint64]string{ var FilterMap = map[uint64]string{
FilterAnd: "And", FilterAnd: "And",
FilterOr: "Or", FilterOr: "Or",
...@@ -41,18 +43,21 @@ var FilterMap = map[uint64]string{ ...@@ -41,18 +43,21 @@ var FilterMap = map[uint64]string{
FilterExtensibleMatch: "Extensible Match", FilterExtensibleMatch: "Extensible Match",
} }
// SubstringFilter options
const ( const (
FilterSubstringsInitial = 0 FilterSubstringsInitial = 0
FilterSubstringsAny = 1 FilterSubstringsAny = 1
FilterSubstringsFinal = 2 FilterSubstringsFinal = 2
) )
// FilterSubstringsMap contains human readable descriptions of SubstringFilter choices
var FilterSubstringsMap = map[uint64]string{ var FilterSubstringsMap = map[uint64]string{
FilterSubstringsInitial: "Substrings Initial", FilterSubstringsInitial: "Substrings Initial",
FilterSubstringsAny: "Substrings Any", FilterSubstringsAny: "Substrings Any",
FilterSubstringsFinal: "Substrings Final", FilterSubstringsFinal: "Substrings Final",
} }
// MatchingRuleAssertion choices
const ( const (
MatchingRuleAssertionMatchingRule = 1 MatchingRuleAssertionMatchingRule = 1
MatchingRuleAssertionType = 2 MatchingRuleAssertionType = 2
...@@ -60,6 +65,7 @@ const ( ...@@ -60,6 +65,7 @@ const (
MatchingRuleAssertionDNAttributes = 4 MatchingRuleAssertionDNAttributes = 4
) )
// MatchingRuleAssertionMap contains human readable descriptions of MatchingRuleAssertion choices
var MatchingRuleAssertionMap = map[uint64]string{ var MatchingRuleAssertionMap = map[uint64]string{
MatchingRuleAssertionMatchingRule: "Matching Rule Assertion Matching Rule", MatchingRuleAssertionMatchingRule: "Matching Rule Assertion Matching Rule",
MatchingRuleAssertionType: "Matching Rule Assertion Type", MatchingRuleAssertionType: "Matching Rule Assertion Type",
...@@ -67,6 +73,7 @@ var MatchingRuleAssertionMap = map[uint64]string{ ...@@ -67,6 +73,7 @@ var MatchingRuleAssertionMap = map[uint64]string{
MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes", MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes",
} }
// CompileFilter converts a string representation of a filter into a BER-encoded packet
func CompileFilter(filter string) (*ber.Packet, error) { func CompileFilter(filter string) (*ber.Packet, error) {
if len(filter) == 0 || filter[0] != '(' { if len(filter) == 0 || filter[0] != '(' {
return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('")) return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('"))
...@@ -75,12 +82,16 @@ func CompileFilter(filter string) (*ber.Packet, error) { ...@@ -75,12 +82,16 @@ func CompileFilter(filter string) (*ber.Packet, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if pos != len(filter) { switch {
case pos > len(filter):
return nil, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
case pos < len(filter):
return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:]))) return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
} }
return packet, nil return packet, nil
} }
// DecompileFilter converts a packet representation of a filter into a string representation
func DecompileFilter(packet *ber.Packet) (ret string, err error) { func DecompileFilter(packet *ber.Packet) (ret string, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
...@@ -239,11 +250,13 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { ...@@ -239,11 +250,13 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
packet.AppendChild(child) packet.AppendChild(child)
return packet, newPos, err return packet, newPos, err
default: default:
READING_ATTR := 0 const (
READING_EXTENSIBLE_MATCHING_RULE := 1 stateReadingAttr = 0
READING_CONDITION := 2 stateReadingExtensibleMatchingRule = 1
stateReadingCondition = 2
)
state := READING_ATTR state := stateReadingAttr
attribute := "" attribute := ""
extensibleDNAttributes := false extensibleDNAttributes := false
...@@ -261,56 +274,56 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { ...@@ -261,56 +274,56 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
} }
switch state { switch state {
case READING_ATTR: case stateReadingAttr:
switch { switch {
// Extensible rule, with only DN-matching // Extensible rule, with only DN-matching
case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:="): case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch]) packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
extensibleDNAttributes = true extensibleDNAttributes = true
state = READING_CONDITION state = stateReadingCondition
newPos += 5 newPos += 5
// Extensible rule, with DN-matching and a matching OID // Extensible rule, with DN-matching and a matching OID
case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:"): case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:"):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch]) packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
extensibleDNAttributes = true extensibleDNAttributes = true
state = READING_EXTENSIBLE_MATCHING_RULE state = stateReadingExtensibleMatchingRule
newPos += 4 newPos += 4
// Extensible rule, with attr only // Extensible rule, with attr only
case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="): case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch]) packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
state = READING_CONDITION state = stateReadingCondition
newPos += 2 newPos += 2
// Extensible rule, with no DN attribute matching // Extensible rule, with no DN attribute matching
case currentRune == ':': case currentRune == ':':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch]) packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
state = READING_EXTENSIBLE_MATCHING_RULE state = stateReadingExtensibleMatchingRule
newPos += 1 newPos++
// Equality condition // Equality condition
case currentRune == '=': case currentRune == '=':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch]) packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
state = READING_CONDITION state = stateReadingCondition
newPos += 1 newPos++
// Greater-than or equal // Greater-than or equal
case currentRune == '>' && strings.HasPrefix(remainingFilter, ">="): case currentRune == '>' && strings.HasPrefix(remainingFilter, ">="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual]) packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
state = READING_CONDITION state = stateReadingCondition
newPos += 2 newPos += 2
// Less-than or equal // Less-than or equal
case currentRune == '<' && strings.HasPrefix(remainingFilter, "<="): case currentRune == '<' && strings.HasPrefix(remainingFilter, "<="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual]) packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
state = READING_CONDITION state = stateReadingCondition
newPos += 2 newPos += 2
// Approx // Approx
case currentRune == '~' && strings.HasPrefix(remainingFilter, "~="): case currentRune == '~' && strings.HasPrefix(remainingFilter, "~="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterApproxMatch]) packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterApproxMatch])
state = READING_CONDITION state = stateReadingCondition
newPos += 2 newPos += 2
// Still reading the attribute name // Still reading the attribute name
...@@ -319,12 +332,12 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { ...@@ -319,12 +332,12 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
newPos += currentWidth newPos += currentWidth
} }
case READING_EXTENSIBLE_MATCHING_RULE: case stateReadingExtensibleMatchingRule:
switch { switch {
// Matching rule OID is done // Matching rule OID is done
case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="): case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="):
state = READING_CONDITION state = stateReadingCondition
newPos += 2 newPos += 2
// Still reading the matching rule oid // Still reading the matching rule oid
...@@ -333,7 +346,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { ...@@ -333,7 +346,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
newPos += currentWidth newPos += currentWidth
} }
case READING_CONDITION: case stateReadingCondition:
// append to the condition // append to the condition
condition += fmt.Sprintf("%c", currentRune) condition += fmt.Sprintf("%c", currentRune)
newPos += currentWidth newPos += currentWidth
...@@ -369,9 +382,9 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { ...@@ -369,9 +382,9 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
} }
// Add the value (only required child) // Add the value (only required child)
encodedString, err := escapedStringToEncodedBytes(condition) encodedString, encodeErr := escapedStringToEncodedBytes(condition)
if err != nil { if encodeErr != nil {
return packet, newPos, err return packet, newPos, encodeErr
} }
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchValue, encodedString, MatchingRuleAssertionMap[MatchingRuleAssertionMatchValue])) packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchValue, encodedString, MatchingRuleAssertionMap[MatchingRuleAssertionMatchValue]))
...@@ -401,17 +414,17 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { ...@@ -401,17 +414,17 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
default: default:
tag = FilterSubstringsAny tag = FilterSubstringsAny
} }
encodedString, err := escapedStringToEncodedBytes(part) encodedString, encodeErr := escapedStringToEncodedBytes(part)
if err != nil { if encodeErr != nil {
return packet, newPos, err return packet, newPos, encodeErr
} }
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, tag, encodedString, FilterSubstringsMap[uint64(tag)])) seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, tag, encodedString, FilterSubstringsMap[uint64(tag)]))
} }
packet.AppendChild(seq) packet.AppendChild(seq)
default: default:
encodedString, err := escapedStringToEncodedBytes(condition) encodedString, encodeErr := escapedStringToEncodedBytes(condition)
if err != nil { if encodeErr != nil {
return packet, newPos, err return packet, newPos, encodeErr
} }
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute")) packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition")) packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition"))
...@@ -440,12 +453,12 @@ func escapedStringToEncodedBytes(escapedString string) (string, error) { ...@@ -440,12 +453,12 @@ func escapedStringToEncodedBytes(escapedString string) (string, error) {
if i+2 > len(escapedString) { if i+2 > len(escapedString) {
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter")) return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
} }
if escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3]); decodeErr != nil { escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3])
if decodeErr != nil {
return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter")) return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter"))
} else { }
buffer.WriteByte(escByte[0]) buffer.WriteByte(escByte[0])
i += 2 // +1 from end of loop, so 3 total for \xx. i += 2 // +1 from end of loop, so 3 total for \xx.
}
} else { } else {
buffer.WriteRune(currentRune) buffer.WriteRune(currentRune)
} }
......
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
ber "gopkg.in/asn1-ber.v1" "gopkg.in/asn1-ber.v1"
) )
// LDAP Application Codes // LDAP Application Codes
...@@ -36,6 +36,7 @@ const ( ...@@ -36,6 +36,7 @@ const (
ApplicationExtendedResponse = 24 ApplicationExtendedResponse = 24
) )
// ApplicationMap contains human readable descriptions of LDAP Application Codes
var ApplicationMap = map[uint8]string{ var ApplicationMap = map[uint8]string{
ApplicationBindRequest: "Bind Request", ApplicationBindRequest: "Bind Request",
ApplicationBindResponse: "Bind Response", ApplicationBindResponse: "Bind Response",
...@@ -72,6 +73,7 @@ const ( ...@@ -72,6 +73,7 @@ const (
BeheraPasswordInHistory = 8 BeheraPasswordInHistory = 8
) )
// BeheraPasswordPolicyErrorMap contains human readable descriptions of Behera Password Policy error codes
var BeheraPasswordPolicyErrorMap = map[int8]string{ var BeheraPasswordPolicyErrorMap = map[int8]string{
BeheraPasswordExpired: "Password expired", BeheraPasswordExpired: "Password expired",
BeheraAccountLocked: "Account locked", BeheraAccountLocked: "Account locked",
...@@ -151,16 +153,47 @@ func addLDAPDescriptions(packet *ber.Packet) (err error) { ...@@ -151,16 +153,47 @@ func addLDAPDescriptions(packet *ber.Packet) (err error) {
func addControlDescriptions(packet *ber.Packet) { func addControlDescriptions(packet *ber.Packet) {
packet.Description = "Controls" packet.Description = "Controls"
for _, child := range packet.Children { for _, child := range packet.Children {
var value *ber.Packet
controlType := ""
child.Description = "Control" child.Description = "Control"
child.Children[0].Description = "Control Type (" + ControlTypeMap[child.Children[0].Value.(string)] + ")" switch len(child.Children) {
value := child.Children[1] case 0:
if len(child.Children) == 3 { // at least one child is required for control type
continue
case 1:
// just type, no criticality or value
controlType = child.Children[0].Value.(string)
child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"
case 2:
controlType = child.Children[0].Value.(string)
child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"
// Children[1] could be criticality or value (both are optional)
// duck-type on whether this is a boolean
if _, ok := child.Children[1].Value.(bool); ok {
child.Children[1].Description = "Criticality" child.Children[1].Description = "Criticality"
value = child.Children[2] } else {
child.Children[1].Description = "Control Value"
value = child.Children[1]
} }
value.Description = "Control Value"
switch child.Children[0].Value.(string) { case 3:
// criticality and value present
controlType = child.Children[0].Value.(string)
child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"
child.Children[1].Description = "Criticality"
child.Children[2].Description = "Control Value"
value = child.Children[2]
default:
// more than 3 children is invalid
continue
}
if value == nil {
continue
}
switch controlType {
case ControlTypePaging: case ControlTypePaging:
value.Description += " (Paging)" value.Description += " (Paging)"
if value.Value != nil { if value.Value != nil {
...@@ -186,18 +219,18 @@ func addControlDescriptions(packet *ber.Packet) { ...@@ -186,18 +219,18 @@ func addControlDescriptions(packet *ber.Packet) {
for _, child := range sequence.Children { for _, child := range sequence.Children {
if child.Tag == 0 { if child.Tag == 0 {
//Warning //Warning
child := child.Children[0] warningPacket := child.Children[0]
packet := ber.DecodePacket(child.Data.Bytes()) packet := ber.DecodePacket(warningPacket.Data.Bytes())
val, ok := packet.Value.(int64) val, ok := packet.Value.(int64)
if ok { if ok {
if child.Tag == 0 { if warningPacket.Tag == 0 {
//timeBeforeExpiration //timeBeforeExpiration
value.Description += " (TimeBeforeExpiration)" value.Description += " (TimeBeforeExpiration)"
child.Value = val warningPacket.Value = val
} else if child.Tag == 1 { } else if warningPacket.Tag == 1 {
//graceAuthNsRemaining //graceAuthNsRemaining
value.Description += " (GraceAuthNsRemaining)" value.Description += " (GraceAuthNsRemaining)"
child.Value = val warningPacket.Value = val
} }
} }
} else if child.Tag == 1 { } else if child.Tag == 1 {
...@@ -237,6 +270,7 @@ func addDefaultLDAPResponseDescriptions(packet *ber.Packet) { ...@@ -237,6 +270,7 @@ func addDefaultLDAPResponseDescriptions(packet *ber.Packet) {
} }
} }
// DebugBinaryFile reads and prints packets from the given filename
func DebugBinaryFile(fileName string) error { func DebugBinaryFile(fileName string) error {
file, err := ioutil.ReadFile(fileName) file, err := ioutil.ReadFile(fileName)
if err != nil { if err != nil {
......
...@@ -36,64 +36,76 @@ import ( ...@@ -36,64 +36,76 @@ import (
"gopkg.in/asn1-ber.v1" "gopkg.in/asn1-ber.v1"
) )
// Change operation choices
const ( const (
AddAttribute = 0 AddAttribute = 0
DeleteAttribute = 1 DeleteAttribute = 1
ReplaceAttribute = 2 ReplaceAttribute = 2
) )
// PartialAttribute for a ModifyRequest as defined in https://tools.ietf.org/html/rfc4511
type PartialAttribute struct { type PartialAttribute struct {
attrType string // Type is the type of the partial attribute
attrVals []string Type string
// Vals are the values of the partial attribute
Vals []string
} }
func (p *PartialAttribute) encode() *ber.Packet { func (p *PartialAttribute) encode() *ber.Packet {
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PartialAttribute") seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PartialAttribute")
seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, p.attrType, "Type")) seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, p.Type, "Type"))
set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue") set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
for _, value := range p.attrVals { for _, value := range p.Vals {
set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals")) set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals"))
} }
seq.AppendChild(set) seq.AppendChild(set)
return seq return seq
} }
// ModifyRequest as defined in https://tools.ietf.org/html/rfc4511
type ModifyRequest struct { type ModifyRequest struct {
dn string // DN is the distinguishedName of the directory entry to modify
addAttributes []PartialAttribute DN string
deleteAttributes []PartialAttribute // AddAttributes contain the attributes to add
replaceAttributes []PartialAttribute AddAttributes []PartialAttribute
// DeleteAttributes contain the attributes to delete
DeleteAttributes []PartialAttribute
// ReplaceAttributes contain the attributes to replace
ReplaceAttributes []PartialAttribute
} }
// Add inserts the given attribute to the list of attributes to add
func (m *ModifyRequest) Add(attrType string, attrVals []string) { func (m *ModifyRequest) Add(attrType string, attrVals []string) {
m.addAttributes = append(m.addAttributes, PartialAttribute{attrType: attrType, attrVals: attrVals}) m.AddAttributes = append(m.AddAttributes, PartialAttribute{Type: attrType, Vals: attrVals})
} }
// Delete inserts the given attribute to the list of attributes to delete
func (m *ModifyRequest) Delete(attrType string, attrVals []string) { func (m *ModifyRequest) Delete(attrType string, attrVals []string) {
m.deleteAttributes = append(m.deleteAttributes, PartialAttribute{attrType: attrType, attrVals: attrVals}) m.DeleteAttributes = append(m.DeleteAttributes, PartialAttribute{Type: attrType, Vals: attrVals})
} }
// Replace inserts the given attribute to the list of attributes to replace
func (m *ModifyRequest) Replace(attrType string, attrVals []string) { func (m *ModifyRequest) Replace(attrType string, attrVals []string) {
m.replaceAttributes = append(m.replaceAttributes, PartialAttribute{attrType: attrType, attrVals: attrVals}) m.ReplaceAttributes = append(m.ReplaceAttributes, PartialAttribute{Type: attrType, Vals: attrVals})
} }
func (m ModifyRequest) encode() *ber.Packet { func (m ModifyRequest) encode() *ber.Packet {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request")
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.dn, "DN")) request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.DN, "DN"))
changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes") changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes")
for _, attribute := range m.addAttributes { for _, attribute := range m.AddAttributes {
change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change") change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(AddAttribute), "Operation")) change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(AddAttribute), "Operation"))
change.AppendChild(attribute.encode()) change.AppendChild(attribute.encode())
changes.AppendChild(change) changes.AppendChild(change)
} }
for _, attribute := range m.deleteAttributes { for _, attribute := range m.DeleteAttributes {
change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change") change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(DeleteAttribute), "Operation")) change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(DeleteAttribute), "Operation"))
change.AppendChild(attribute.encode()) change.AppendChild(attribute.encode())
changes.AppendChild(change) changes.AppendChild(change)
} }
for _, attribute := range m.replaceAttributes { for _, attribute := range m.ReplaceAttributes {
change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change") change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ReplaceAttribute), "Operation")) change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ReplaceAttribute), "Operation"))
change.AppendChild(attribute.encode()) change.AppendChild(attribute.encode())
...@@ -103,38 +115,36 @@ func (m ModifyRequest) encode() *ber.Packet { ...@@ -103,38 +115,36 @@ func (m ModifyRequest) encode() *ber.Packet {
return request return request
} }
// NewModifyRequest creates a modify request for the given DN
func NewModifyRequest( func NewModifyRequest(
dn string, dn string,
) *ModifyRequest { ) *ModifyRequest {
return &ModifyRequest{ return &ModifyRequest{
dn: dn, DN: dn,
} }
} }
// Modify performs the ModifyRequest
func (l *Conn) Modify(modifyRequest *ModifyRequest) error { func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
packet.AppendChild(modifyRequest.encode()) packet.AppendChild(modifyRequest.encode())
l.Debug.PrintPacket(packet) l.Debug.PrintPacket(packet)
channel, err := l.sendMessage(packet) msgCtx, err := l.sendMessage(packet)
if err != nil { if err != nil {
return err return err
} }
if channel == nil { defer l.finishMessage(msgCtx)
return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
l.Debug.Printf("%d: waiting for response", messageID) l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-channel packetResponse, ok := <-msgCtx.responses
if !ok { if !ok {
return NewError(ErrorNetwork, errors.New("ldap: channel closed")) return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
} }
packet, err = packetResponse.ReadPacket() packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", messageID, packet) l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil { if err != nil {
return err return err
} }
...@@ -155,6 +165,6 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error { ...@@ -155,6 +165,6 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
log.Printf("Unexpected Response: %d", packet.Children[1].Tag) log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
} }
l.Debug.Printf("%d: returning", messageID) l.Debug.Printf("%d: returning", msgCtx.id)
return nil return nil
} }
...@@ -16,13 +16,21 @@ const ( ...@@ -16,13 +16,21 @@ const (
passwordModifyOID = "1.3.6.1.4.1.4203.1.11.1" passwordModifyOID = "1.3.6.1.4.1.4203.1.11.1"
) )
// PasswordModifyRequest implements the Password Modify Extended Operation as defined in https://www.ietf.org/rfc/rfc3062.txt
type PasswordModifyRequest struct { type PasswordModifyRequest struct {
// UserIdentity is an optional string representation of the user associated with the request.
// This string may or may not be an LDAPDN [RFC2253].
// If no UserIdentity field is present, the request acts up upon the password of the user currently associated with the LDAP session
UserIdentity string UserIdentity string
// OldPassword, if present, contains the user's current password
OldPassword string OldPassword string
// NewPassword, if present, contains the desired password for this user
NewPassword string NewPassword string
} }
// PasswordModifyResult holds the server response to a PasswordModifyRequest
type PasswordModifyResult struct { type PasswordModifyResult struct {
// GeneratedPassword holds a password generated by the server, if present
GeneratedPassword string GeneratedPassword string
} }
...@@ -47,7 +55,7 @@ func (r *PasswordModifyRequest) encode() (*ber.Packet, error) { ...@@ -47,7 +55,7 @@ func (r *PasswordModifyRequest) encode() (*ber.Packet, error) {
return request, nil return request, nil
} }
// Create a new PasswordModifyRequest // NewPasswordModifyRequest creates a new PasswordModifyRequest
// //
// According to the RFC 3602: // According to the RFC 3602:
// userIdentity is a string representing the user associated with the request. // userIdentity is a string representing the user associated with the request.
...@@ -72,11 +80,10 @@ func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPasswo ...@@ -72,11 +80,10 @@ func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPasswo
} }
} }
// PasswordModify performs the modification request
func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) { func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
encodedPasswordModifyRequest, err := passwordModifyRequest.encode() encodedPasswordModifyRequest, err := passwordModifyRequest.encode()
if err != nil { if err != nil {
...@@ -86,24 +93,21 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa ...@@ -86,24 +93,21 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa
l.Debug.PrintPacket(packet) l.Debug.PrintPacket(packet)
channel, err := l.sendMessage(packet) msgCtx, err := l.sendMessage(packet)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if channel == nil { defer l.finishMessage(msgCtx)
return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
result := &PasswordModifyResult{} result := &PasswordModifyResult{}
l.Debug.Printf("%d: waiting for response", messageID) l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-channel packetResponse, ok := <-msgCtx.responses
if !ok { if !ok {
return nil, NewError(ErrorNetwork, errors.New("ldap: channel closed")) return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
} }
packet, err = packetResponse.ReadPacket() packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", messageID, packet) l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -131,10 +135,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa ...@@ -131,10 +135,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa
extendedResponse := packet.Children[1] extendedResponse := packet.Children[1]
for _, child := range extendedResponse.Children { for _, child := range extendedResponse.Children {
if child.Tag == 11 { if child.Tag == 11 {
passwordModifyReponseValue := ber.DecodePacket(child.Data.Bytes()) passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes())
if len(passwordModifyReponseValue.Children) == 1 { if len(passwordModifyResponseValue.Children) == 1 {
if passwordModifyReponseValue.Children[0].Tag == 0 { if passwordModifyResponseValue.Children[0].Tag == 0 {
result.GeneratedPassword = ber.DecodeString(passwordModifyReponseValue.Children[0].Data.Bytes()) result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes())
} }
} }
} }
......
...@@ -68,18 +68,21 @@ import ( ...@@ -68,18 +68,21 @@ import (
"gopkg.in/asn1-ber.v1" "gopkg.in/asn1-ber.v1"
) )
// scope choices
const ( const (
ScopeBaseObject = 0 ScopeBaseObject = 0
ScopeSingleLevel = 1 ScopeSingleLevel = 1
ScopeWholeSubtree = 2 ScopeWholeSubtree = 2
) )
// ScopeMap contains human readable descriptions of scope choices
var ScopeMap = map[int]string{ var ScopeMap = map[int]string{
ScopeBaseObject: "Base Object", ScopeBaseObject: "Base Object",
ScopeSingleLevel: "Single Level", ScopeSingleLevel: "Single Level",
ScopeWholeSubtree: "Whole Subtree", ScopeWholeSubtree: "Whole Subtree",
} }
// derefAliases
const ( const (
NeverDerefAliases = 0 NeverDerefAliases = 0
DerefInSearching = 1 DerefInSearching = 1
...@@ -87,6 +90,7 @@ const ( ...@@ -87,6 +90,7 @@ const (
DerefAlways = 3 DerefAlways = 3
) )
// DerefMap contains human readable descriptions of derefAliases choices
var DerefMap = map[int]string{ var DerefMap = map[int]string{
NeverDerefAliases: "NeverDerefAliases", NeverDerefAliases: "NeverDerefAliases",
DerefInSearching: "DerefInSearching", DerefInSearching: "DerefInSearching",
...@@ -114,11 +118,15 @@ func NewEntry(dn string, attributes map[string][]string) *Entry { ...@@ -114,11 +118,15 @@ func NewEntry(dn string, attributes map[string][]string) *Entry {
} }
} }
// Entry represents a single search result entry
type Entry struct { type Entry struct {
// DN is the distinguished name of the entry
DN string DN string
// Attributes are the returned attributes for the entry
Attributes []*EntryAttribute Attributes []*EntryAttribute
} }
// GetAttributeValues returns the values for the named attribute, or an empty list
func (e *Entry) GetAttributeValues(attribute string) []string { func (e *Entry) GetAttributeValues(attribute string) []string {
for _, attr := range e.Attributes { for _, attr := range e.Attributes {
if attr.Name == attribute { if attr.Name == attribute {
...@@ -128,6 +136,7 @@ func (e *Entry) GetAttributeValues(attribute string) []string { ...@@ -128,6 +136,7 @@ func (e *Entry) GetAttributeValues(attribute string) []string {
return []string{} return []string{}
} }
// GetRawAttributeValues returns the byte values for the named attribute, or an empty list
func (e *Entry) GetRawAttributeValues(attribute string) [][]byte { func (e *Entry) GetRawAttributeValues(attribute string) [][]byte {
for _, attr := range e.Attributes { for _, attr := range e.Attributes {
if attr.Name == attribute { if attr.Name == attribute {
...@@ -137,6 +146,7 @@ func (e *Entry) GetRawAttributeValues(attribute string) [][]byte { ...@@ -137,6 +146,7 @@ func (e *Entry) GetRawAttributeValues(attribute string) [][]byte {
return [][]byte{} return [][]byte{}
} }
// GetAttributeValue returns the first value for the named attribute, or ""
func (e *Entry) GetAttributeValue(attribute string) string { func (e *Entry) GetAttributeValue(attribute string) string {
values := e.GetAttributeValues(attribute) values := e.GetAttributeValues(attribute)
if len(values) == 0 { if len(values) == 0 {
...@@ -145,6 +155,7 @@ func (e *Entry) GetAttributeValue(attribute string) string { ...@@ -145,6 +155,7 @@ func (e *Entry) GetAttributeValue(attribute string) string {
return values[0] return values[0]
} }
// GetRawAttributeValue returns the first value for the named attribute, or an empty slice
func (e *Entry) GetRawAttributeValue(attribute string) []byte { func (e *Entry) GetRawAttributeValue(attribute string) []byte {
values := e.GetRawAttributeValues(attribute) values := e.GetRawAttributeValues(attribute)
if len(values) == 0 { if len(values) == 0 {
...@@ -153,6 +164,7 @@ func (e *Entry) GetRawAttributeValue(attribute string) []byte { ...@@ -153,6 +164,7 @@ func (e *Entry) GetRawAttributeValue(attribute string) []byte {
return values[0] return values[0]
} }
// Print outputs a human-readable description
func (e *Entry) Print() { func (e *Entry) Print() {
fmt.Printf("DN: %s\n", e.DN) fmt.Printf("DN: %s\n", e.DN)
for _, attr := range e.Attributes { for _, attr := range e.Attributes {
...@@ -160,6 +172,7 @@ func (e *Entry) Print() { ...@@ -160,6 +172,7 @@ func (e *Entry) Print() {
} }
} }
// PrettyPrint outputs a human-readable description indenting
func (e *Entry) PrettyPrint(indent int) { func (e *Entry) PrettyPrint(indent int) {
fmt.Printf("%sDN: %s\n", strings.Repeat(" ", indent), e.DN) fmt.Printf("%sDN: %s\n", strings.Repeat(" ", indent), e.DN)
for _, attr := range e.Attributes { for _, attr := range e.Attributes {
...@@ -180,38 +193,51 @@ func NewEntryAttribute(name string, values []string) *EntryAttribute { ...@@ -180,38 +193,51 @@ func NewEntryAttribute(name string, values []string) *EntryAttribute {
} }
} }
// EntryAttribute holds a single attribute
type EntryAttribute struct { type EntryAttribute struct {
// Name is the name of the attribute
Name string Name string
// Values contain the string values of the attribute
Values []string Values []string
// ByteValues contain the raw values of the attribute
ByteValues [][]byte ByteValues [][]byte
} }
// Print outputs a human-readable description
func (e *EntryAttribute) Print() { func (e *EntryAttribute) Print() {
fmt.Printf("%s: %s\n", e.Name, e.Values) fmt.Printf("%s: %s\n", e.Name, e.Values)
} }
// PrettyPrint outputs a human-readable description with indenting
func (e *EntryAttribute) PrettyPrint(indent int) { func (e *EntryAttribute) PrettyPrint(indent int) {
fmt.Printf("%s%s: %s\n", strings.Repeat(" ", indent), e.Name, e.Values) fmt.Printf("%s%s: %s\n", strings.Repeat(" ", indent), e.Name, e.Values)
} }
// SearchResult holds the server's response to a search request
type SearchResult struct { type SearchResult struct {
// Entries are the returned entries
Entries []*Entry Entries []*Entry
// Referrals are the returned referrals
Referrals []string Referrals []string
// Controls are the returned controls
Controls []Control Controls []Control
} }
// Print outputs a human-readable description
func (s *SearchResult) Print() { func (s *SearchResult) Print() {
for _, entry := range s.Entries { for _, entry := range s.Entries {
entry.Print() entry.Print()
} }
} }
// PrettyPrint outputs a human-readable description with indenting
func (s *SearchResult) PrettyPrint(indent int) { func (s *SearchResult) PrettyPrint(indent int) {
for _, entry := range s.Entries { for _, entry := range s.Entries {
entry.PrettyPrint(indent) entry.PrettyPrint(indent)
} }
} }
// SearchRequest represents a search request to send to the server
type SearchRequest struct { type SearchRequest struct {
BaseDN string BaseDN string
Scope int Scope int
...@@ -247,6 +273,7 @@ func (s *SearchRequest) encode() (*ber.Packet, error) { ...@@ -247,6 +273,7 @@ func (s *SearchRequest) encode() (*ber.Packet, error) {
return request, nil return request, nil
} }
// NewSearchRequest creates a new search request
func NewSearchRequest( func NewSearchRequest(
BaseDN string, BaseDN string,
Scope, DerefAliases, SizeLimit, TimeLimit int, Scope, DerefAliases, SizeLimit, TimeLimit int,
...@@ -341,10 +368,10 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) ...@@ -341,10 +368,10 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32)
return searchResult, nil return searchResult, nil
} }
// Search performs the given search request
func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
// encode search request // encode search request
encodedSearchRequest, err := searchRequest.encode() encodedSearchRequest, err := searchRequest.encode()
if err != nil { if err != nil {
...@@ -358,14 +385,11 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { ...@@ -358,14 +385,11 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
l.Debug.PrintPacket(packet) l.Debug.PrintPacket(packet)
channel, err := l.sendMessage(packet) msgCtx, err := l.sendMessage(packet)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if channel == nil { defer l.finishMessage(msgCtx)
return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
result := &SearchResult{ result := &SearchResult{
Entries: make([]*Entry, 0), Entries: make([]*Entry, 0),
...@@ -374,13 +398,13 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { ...@@ -374,13 +398,13 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
foundSearchResultDone := false foundSearchResultDone := false
for !foundSearchResultDone { for !foundSearchResultDone {
l.Debug.Printf("%d: waiting for response", messageID) l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-channel packetResponse, ok := <-msgCtx.responses
if !ok { if !ok {
return nil, NewError(ErrorNetwork, errors.New("ldap: channel closed")) return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
} }
packet, err = packetResponse.ReadPacket() packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", messageID, packet) l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -421,6 +445,6 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { ...@@ -421,6 +445,6 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string)) result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string))
} }
} }
l.Debug.Printf("%d: returning", messageID) l.Debug.Printf("%d: returning", msgCtx.id)
return result, nil return result, nil
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment