Index: if_vxn.c =================================================================== RCS file: if_vxn.c diff -N if_vxn.c --- /dev/null 1 Jan 1970 00:00:00 -0000 +++ if_vxn.c 28 Apr 2021 01:00:19 -0000 @@ -0,0 +1,2419 @@ +/* $OpenBSD$ */ + +/* + * Copyright (c) 2021 David Gwynne + * Copyright (C) 2015-2020 Jason A. Donenfeld . All Rights Rese +rved. + * Copyright (C) 2019-2020 Matt Dunwoodie + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ */ + +#include "bpfilter.h" +#include "pf.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#ifdef INET6 +#include +#include +#include +#endif + +#ifdef MPLS +#include +#endif /* MPLS */ + +#include +#include + +#include +#include + +#if NBPFILTER > 0 +#include +#endif + +#if NPF > 0 +#include +#endif + +/* + * Virtual encrypted networks + */ + +/* The protocol */ + +struct vxnt_header { + uint32_t vxnt_type; +#define VXNT_INIT (1U << 24) +#define VXNT_RESPONSE (2U << 24) +#define VXNT_COOKIE (3U << 24) +#define VXNT_DATA (4U << 24) +}; + +struct vxnt_header_init { + uint32_t vxnti_type; + uint32_t vxnti_sindex; + uint8_t vxnti_ue[NOISE_PUBLIC_KEY_LEN]; + uint8_t vxnti_es[NOISE_PUBLIC_KEY_LEN + + NOISE_AUTHTAG_LEN]; + uint8_t vxnti_ets[NOISE_TIMESTAMP_LEN + + NOISE_AUTHTAG_LEN]; + struct cookie_macs vxnti_mac; +}; + +struct vxnt_header_response { + uint32_t vxntr_type; + uint32_t vxntr_sindex; + uint32_t vxntr_rindex; + uint8_t vxntr_ue[NOISE_PUBLIC_KEY_LEN]; + uint8_t vxntr_en[0 + NOISE_AUTHTAG_LEN]; + struct cookie_macs vxntr_mac; +}; + +struct vxnt_header_cookie { + uint32_t vxntc_type; + uint32_t vxntc_rindex; + uint8_t vxntc_nonce[COOKIE_NONCE_SIZE]; + uint8_t vxntc_ec[COOKIE_ENCRYPTED_SIZE]; +}; + +struct vxnt_header_data { + uint32_t vxntd_type; + uint32_t vxntd_rindex; + uint32_t vxntd_seq_hi; + uint32_t vxntd_seq_lo; + uint8_t vxntd_mac[NOISE_AUTHTAG_LEN]; +}; + +/* + * vxn tunnel header + * + * Multiple networks are transported over a single vxn connection + * between two endpoints. The encrypted payload contains a vxn + * tunnel prefix followed by the actual network packet. 
This prefix + * identifies the type of network packet (IPv4, IPv6, MPLS, Ethernet, + * etc), the virtual network, and tunnel options data associated + * with the packet. The tunnel options are deliberately compatible + * with those in RFC 8926 Geneve: Generic Network Virtualization + * Encapsulation. Note that the vxn tunnel header is different from + * the Geneve header. + */ + +struct vxn_header { + uint8_t vxn_ver_prio; +#define VXNT_VER_SHIFT 4 +#define VXNT_VER_MASK (0xf << VXNT_VER_SHIFT) +#define VXNT_VER_0 (0x0 << VXNT_VER_SHIFT) +#define VXNT_PAD_SHIFT 3 +#define VXNT_PAD (0x1 << VXNT_PAD_SHIFT) +#define VXNT_PRIO_SHIFT 0 +#define VXNT_PRIO_MASK (0x7 << VXNT_PRIO_SHIFT) + uint8_t vxn_opts; +#define VXNT_OPTLEN_SHIFT 2 +#define VXNT_OPTLEN_MASK (0xfc << VXNT_OPTLEN_SHIFT) +#define VXNT_OPTLEN_UNIT 4U +#define VXNT_F_OAM (0x01 << 1) +#define VXNT_F_CRITICAL (0x01 << 0) + uint16_t vxn_proto; + uint32_t vxn_vni; +}; + +/* The vxn_t_option header is compatible with geneve */ + +struct vxn_option { + uint16_t vxno_class; + uint8_t vxno_type; +#define VXNT_OPTION_C (1U << 7) +#define VXNT_OPTION_TYPE_MASK 0x7fU + uint8_t vxno_flags; +#define VXNT_OPTION_LEN_SHIFT 0 +#define VXNT_OPTION_LEN_MASK (0x1fU << GENEVE_OPTION_LEN_SHIFT) +#define VXNT_OPTION_LEN_UNIT 4U +}; + +/* + * The driver. 
+ */ + +#define VXNT_IFTYPE IFT_DUMMY + +union vxnt_addr { + struct in_addr in4; + struct in6_addr in6; +}; + +struct vxnt_softc; + +struct vxnt_index { + RBT_ENTRY(vxnt_index) vi_entry; + uint32_t vi_index; + struct vxnt_softc *vi_vxnt_smr; +}; + +RBT_HEAD(vxnt_indexes, vxnt_index); + +struct vxntep { + TAILQ_ENTRY(vxntep) vxntep_entry; + + sa_family_t vxntep_af; + unsigned int vxntep_rdomain; + union vxnt_addr vxntep_addr; +#define vxntep_addr4 vxntep_addr.in4 +#define vxntep_addr6 vxntep_addr.in6 + in_port_t vxntep_port; + + struct socket *vxntep_so; + + struct mutex vxntep_mtx; + struct vxnt_indexes vxntep_indexes; + + struct mbuf_queue vxntep_hs_mq; + struct task vxntep_hs_task; +}; + +TAILQ_HEAD(vxnteps, vxntep); + +struct vxn_key { + RBT_ENTRY(vxn_key) vk_entry; + uint32_t vk_vnetid; + uint16_t vk_proto; +}; + +RBT_HEAD(vxn_keys, vxn_key); + +struct vxnt_softc { + struct ifnet sct_if; + struct vxnt_index *sct_index; + struct noise_remote sct_remote; + + /* configured addresses */ + unsigned int sct_rdomain; + sa_family_t sct_af; + + union vxnt_addr sct_saddr; + union vxnt_addr sct_daddr; + in_port_t sct_sport; + in_port_t sct_dport; + unsigned int sct_mode; +#define VNXT_MODE_HUB 0 +#define VNXT_MODE_SPOKE 1 + + /* addresses on the wire */ + union vxnt_addr sct_laddr; + union vxnt_addr sct_faddr; + in_port_t sct_lport; + in_port_t sct_fport; + + struct vxntep *sct_vxntep; + + uint16_t sct_df; + int sct_ttl; + + struct mutex sct_if_mtx; + struct vxn_keys sct_if_keys; +}; + +struct vxn_softc { + struct vxn_key sc_key; /* must be first */ + struct ifnet *sc_ifp; + void (*sc_transmit)(struct vxn_softc *, struct ifnet *, struct mbuf *); + + unsigned int sc_if_index0; + struct task sc_dhook; + struct task sc_lhook; + + unsigned int sc_dead; + int sc_txhprio; + int sc_rxhprio; + + struct refcnt sc_refs; +}; + +void vxnattach(int); + +static int vxnt_clone_create(struct if_clone *, int); +static int vxnt_clone_destroy(struct ifnet *); + +static int 
vxnt_output(struct ifnet *, struct mbuf *, + struct sockaddr *, struct rtentry *); +static int vxnt_enqueue(struct ifnet *, struct mbuf *); +static void vxnt_start(struct ifqueue *); +static void vxnt_send(struct vxnt_softc *, struct mbuf_list *); + +static int vxnt_ioctl(struct ifnet *, u_long, caddr_t); +static int vxnt_up(struct vxnt_softc *); +static int vxnt_down(struct vxnt_softc *); + +static struct mbuf * + vxnt_input(void *, struct mbuf *, + struct ip *, struct ip6_hdr *, void *, int); +static struct mbuf * + vxnt_encrypt(struct vxnt_softc *, struct noise_keypair *, + struct mbuf *, uint64_t); + +static int vxnt_set_rdomain(struct vxnt_softc *, const struct ifreq *); +static int vxnt_get_rdomain(struct vxnt_softc *, struct ifreq *); +static int vxnt_set_tunnel(struct vxnt_softc *, + const struct if_laddrreq *); +static int vxnt_get_tunnel(struct vxnt_softc *, struct if_laddrreq *); +static int vxnt_del_tunnel(struct vxnt_softc *); +static int vxnt_set_vnetid(struct vxnt_softc *, const struct ifreq *); +static int vxnt_get_vnetid(struct vxnt_softc *, struct ifreq *); + +static void vxn_detach_hook(void *); +static void vxn_link_hook(void *); + +static int vxn_up(struct vxn_softc *); +static int vxn_down(struct vxn_softc *); +static void vxn_input(struct vxnt_softc *, struct mbuf *); +static int vxn_enqueue(struct ifnet *, struct mbuf *); +static void vxn_start_transmit(struct vxn_softc *, struct ifnet *, + struct mbuf *); +static void vxn_encap(struct vxn_softc *, struct ifnet *, struct mbuf *, + uint8_t, uint16_t); + +static int vxn_if_running(struct ifnet *); + +static int vxn_set_parent(struct vxn_softc *, const struct if_parent *); +static int vxn_get_parent(struct vxn_softc *, struct if_parent *); +static int vxn_del_parent(struct vxn_softc *); + +static int vxn_set_vnetid(struct vxn_softc *, const struct ifreq *); +static int vxn_get_vnetid(struct vxn_softc *, struct ifreq *); + +static int vxn_clone_create(struct if_clone *, int); +static int 
vxn_clone_destroy(struct ifnet *); +static int vxn_ioctl(struct ifnet *, u_long, caddr_t); +static int vxn_output(struct ifnet *, struct mbuf *, struct sockaddr *, + struct rtentry *); +static void vxn_transmit(struct vxn_softc *, struct ifnet *, + struct mbuf *); +static void vxn_start(struct ifqueue *); + +static int evxn_clone_create(struct if_clone *, int); +static int evxn_clone_destroy(struct ifnet *); +static int evxn_ioctl(struct ifnet *, u_long, caddr_t); +static void evxn_transmit(struct vxn_softc *, struct ifnet *, + struct mbuf *); +static void evxn_start(struct ifqueue *); + +static struct if_clone vxnt_cloner = + IF_CLONE_INITIALIZER("vxnt", vxnt_clone_create, vxnt_clone_destroy); + +static struct if_clone vxn_cloner = + IF_CLONE_INITIALIZER("vxn", vxn_clone_create, vxn_clone_destroy); +static struct if_clone evxn_cloner = + IF_CLONE_INITIALIZER("evxn", evxn_clone_create, evxn_clone_destroy); + +static struct rwlock vxnt_lock = RWLOCK_INITIALIZER("vxnteps"); +static struct vxnteps vxnteps = TAILQ_HEAD_INITIALIZER(vxnteps); +static struct taskq *vxntq; + +RBT_PROTOTYPE(vxnt_indexes, vxnt_index, vi_entry, vxnt_index_cmp); +RBT_PROTOTYPE(vxn_keys, vxn_key, vk_entry, vxn_key_cmp); + +struct noise_mac { + uint32_t words[NOISE_AUTHTAG_LEN / sizeof(uint32_t)]; +}; +CTASSERT(sizeof(struct noise_mac) == NOISE_AUTHTAG_LEN); + +void +vxnattach(int count) +{ + if_clone_attach(&vxnt_cloner); +} + +static int +vxnt_clone_create(struct if_clone *ifc, int unit) +{ + struct vxnt_softc *sct; + struct ifnet *ifp0; + + if (vxntq == NULL) { + rw_enter_write(&vxnt_lock); /* borrow a lock */ + if (vxntq == NULL) { + vxntq = taskq_create("vxntq", 1, IPL_SOFTNET, + TASKQ_MPSAFE); + } + rw_exit_write(&vxnt_lock); + } + + sct = malloc(sizeof(*sct), M_DEVBUF, M_WAITOK|M_ZERO|M_CANFAIL); + if (sct == NULL) + return (ENOMEM); + + ifp0 = &sct->sct_if; + + snprintf(ifp0->if_xname, sizeof(ifp0->if_xname), "%s%d", + ifc->ifc_name, unit); + + sct->sct_af = AF_UNSPEC; + sct->sct_df = 0; 
+ sct->sct_ttl = ip_defttl; + + ifp0->if_softc = sct; + ifp0->if_hardmtu = 32768; /* XXX */ + ifp0->if_mtu = 0; + ifp0->if_ioctl = vxnt_ioctl; + ifp0->if_output = vxnt_output; + ifp0->if_enqueue = vxnt_enqueue; + ifp0->if_qstart = vxnt_start; + ifp0->if_flags = IFF_POINTOPOINT; + ifp0->if_xflags = IFXF_CLONED | IFXF_MPSAFE; + + if_counters_alloc(ifp0); + if_attach(ifp0); + + return (0); +} + +static int +vxnt_clone_destroy(struct ifnet *ifp0) +{ + struct vxnt_softc *sct = ifp0->if_softc; + + NET_LOCK(); + if (ISSET(ifp0->if_flags, IFF_RUNNING)) + vxnt_down(sct); + NET_UNLOCK(); + + if_detach(ifp0); + + free(sct, M_DEVBUF, sizeof(*sct)); + + return (0); +} + +static void +vxn_softc_init(struct vxn_softc *sc, struct ifnet *ifp, uint16_t proto, + void (*transmit)(struct vxn_softc *, struct ifnet *, struct mbuf *)) +{ + sc->sc_ifp = ifp; + sc->sc_key.vk_proto = proto; + sc->sc_key.vk_vnetid = htonl(0); + sc->sc_transmit = transmit; + + task_set(&sc->sc_dhook, vxn_detach_hook, sc); + task_set(&sc->sc_lhook, vxn_link_hook, sc); + + sc->sc_txhprio = IF_HDRPRIO_PACKET; + sc->sc_rxhprio = IF_HDRPRIO_OUTER; + + refcnt_init(&sc->sc_refs); +} + +static inline void +vxn_rele(struct vxn_softc *sc) +{ + refcnt_rele_wake(&sc->sc_refs); +} + +static int +vxn_clone_create(struct if_clone *ifc, int unit) +{ + struct vxn_softc *sc; + struct ifnet *ifp; + + sc = malloc(sizeof(*sc), M_DEVBUF, M_WAITOK|M_ZERO|M_CANFAIL); + if (sc == NULL) + return (ENOMEM); + + ifp = malloc(sizeof(*ifp), M_DEVBUF, M_WAITOK|M_ZERO|M_CANFAIL); + if (ifp == NULL) { + free(sc, M_DEVBUF, sizeof(*sc)); + return (ENOMEM); + } + + snprintf(ifp->if_xname, sizeof(ifp->if_xname), "%s%d", + ifc->ifc_name, unit); + + vxn_softc_init(sc, ifp, 0, vxn_transmit); + + ifp->if_softc = sc; + ifp->if_hardmtu = 32768; /* XXX */ + ifp->if_mtu = 1280; + ifp->if_ioctl = vxn_ioctl; + ifp->if_rtrequest = p2p_rtrequest; + ifp->if_output = vxn_output; + ifp->if_enqueue = vxn_enqueue; + ifp->if_qstart = vxn_start; + ifp->if_flags = 
IFF_POINTOPOINT | IFF_MULTICAST;
+	ifp->if_xflags = IFXF_CLONED | IFXF_MPSAFE;
+
+	if_counters_alloc(ifp);
+	if_attach(ifp);
+
+	return (0);
+}
+
+static int
+vxn_clone_destroy(struct ifnet *ifp)
+{
+	struct vxn_softc *sc = ifp->if_softc;
+
+	NET_LOCK();
+	sc->sc_dead = 1;
+
+	if (ISSET(ifp->if_flags, IFF_RUNNING))
+		vxn_down(sc);
+	NET_UNLOCK();
+
+	if_detach(ifp);
+
+	free(sc, M_DEVBUF, sizeof(*sc));
+	free(ifp, M_DEVBUF, sizeof(*ifp));
+
+	return (0);
+}
+
+static int
+evxn_clone_create(struct if_clone *ifc, int unit)
+{
+	struct vxn_softc *sc;
+	struct arpcom *ac;
+	struct ifnet *ifp;
+
+	sc = malloc(sizeof(*sc), M_DEVBUF, M_WAITOK|M_ZERO|M_CANFAIL);
+	if (sc == NULL)
+		return (ENOMEM);
+
+	ac = malloc(sizeof(*ac), M_DEVBUF, M_WAITOK|M_ZERO|M_CANFAIL);
+	/*
+	 * XXX fix: the original tested the still-uninitialised ifp
+	 * pointer here; check the allocation that was just made.
+	 */
+	if (ac == NULL) {
+		free(sc, M_DEVBUF, sizeof(*sc));
+		return (ENOMEM);
+	}
+	ifp = &ac->ac_if;
+
+	snprintf(ifp->if_xname, sizeof(ifp->if_xname), "%s%d",
+	    ifc->ifc_name, unit);
+
+	vxn_softc_init(sc, ifp, htons(ETHERTYPE_TRANSETHER), evxn_transmit);
+
+	ifp->if_softc = sc;
+	ifp->if_hardmtu = 32768; /* XXX */
+	ifp->if_ioctl = evxn_ioctl;
+	ifp->if_enqueue = vxn_enqueue;
+	ifp->if_qstart = evxn_start;
+	ifp->if_flags = IFF_BROADCAST | IFF_SIMPLEX | IFF_MULTICAST;
+	ifp->if_xflags = IFXF_CLONED | IFXF_MPSAFE;
+	ether_fakeaddr(ifp);
+
+	if_counters_alloc(ifp);
+	if_attach(ifp);
+	ether_ifattach(ifp);
+
+	return (0);
+}
+
+static int
+evxn_clone_destroy(struct ifnet *ifp)
+{
+	struct vxn_softc *sc = ifp->if_softc;
+	struct arpcom *ac = (struct arpcom *)ifp;
+
+	NET_LOCK();
+	sc->sc_dead = 1;
+	if (ISSET(ifp->if_flags, IFF_RUNNING))
+		vxn_down(sc);
+	NET_UNLOCK();
+
+	/*
+	 * XXX fix: tear the Ethernet state down before detaching the
+	 * interface, mirroring the attach order above.
+	 */
+	ether_ifdetach(ifp);
+	if_detach(ifp);
+
+	free(sc, M_DEVBUF, sizeof(*sc));
+	free(ac, M_DEVBUF, sizeof(*ac));
+
+	return (0);
+}
+
+static int
+vxn_if_running(struct ifnet *ifp)
+{
+	return (ifp != NULL &&
+	    ISSET(ifp->if_flags, IFF_RUNNING) &&
+	    LINK_STATE_IS_UP(ifp->if_link_state));
+}
+
+static struct mbuf *
+vxnt_encap_ipv4(struct vxnt_softc
*sct, struct mbuf *m, + const union vxnt_addr *src, const union vxnt_addr *dst, uint8_t tos) +{ + struct ip *ip; + + m = m_prepend(m, sizeof(*ip), M_DONTWAIT); + if (m == NULL) + return (NULL); + + ip = mtod(m, struct ip *); + ip->ip_v = IPVERSION; + ip->ip_hl = sizeof(*ip) >> 2; + ip->ip_off = sct->sct_df; + ip->ip_tos = tos; + ip->ip_len = htons(m->m_pkthdr.len); + ip->ip_ttl = sct->sct_ttl; + ip->ip_p = IPPROTO_UDP; + ip->ip_src = src->in4; + ip->ip_dst = dst->in4; + + return (m); +} + +#ifdef INET6 +static struct mbuf * +vxnt_encap_ipv6(struct vxnt_softc *sct, struct mbuf *m, + const union vxnt_addr *src, const union vxnt_addr *dst, uint8_t tos) +{ + struct ip6_hdr *ip6; + int len = m->m_pkthdr.len; + + m = m_prepend(m, sizeof(*ip6), M_DONTWAIT); + if (m == NULL) + return (NULL); + + ip6 = mtod(m, struct ip6_hdr *); + ip6->ip6_flow = 0; + ip6->ip6_vfc |= IPV6_VERSION; + ip6->ip6_flow |= htonl((uint32_t)tos << 20); + ip6->ip6_plen = htons(len); + ip6->ip6_nxt = IPPROTO_UDP; + ip6->ip6_hlim = sct->sct_ttl; + ip6->ip6_src = src->in6; + ip6->ip6_dst = dst->in6; + + if (sct->sct_df) + SET(m->m_pkthdr.csum_flags, M_IPV6_DF_OUT); + + return (m); +} +#endif /* INET6 */ + +static void +vxnt_handshake_input(struct vxntep *vxntep, struct mbuf *m, int hlen, + size_t mlen) +{ + int plen = hlen + mlen; + + if (m->m_pkthdr.len < plen) { + m_freem(m); + return; + } + + if (mq_push(&vxntep->vxntep_hs_mq, m) == 0) + task_add(vxntq, &vxntep->vxntep_hs_task); +} + +static inline struct vxnt_softc * +vxnt_find(struct vxntep *vxntep, uint32_t index) +{ + struct vxnt_index key = { .vi_index = index }; + struct vxnt_index *vi; + struct vxnt_softc *sct = NULL; + + mtx_enter(&vxntep->vxntep_mtx); + vi = RBT_FIND(vxnt_indexes, &vxntep->vxntep_indexes, &key); + if (vi != NULL) + sct = SMR_PTR_GET(&vi->vi_vxnt_smr); + mtx_leave(&vxntep->vxntep_mtx); + + return (sct); +} + +static int +vxnt_addr_eq(const union vxnt_addr *vaa, const union vxnt_addr *vab) +{ + size_t i; + + for (i = 0; i < 
nitems(vaa->in6.s6_addr32); i++) { + if (vaa->in6.s6_addr32[i] > vab->in6.s6_addr32[i]) + return (1); + if (vaa->in6.s6_addr32[i] < vab->in6.s6_addr32[i]) + return (-1); + } + + return (0); +} + +static int +vxnt_mac_eq(const void *a, const void *b) +{ + const struct noise_mac *maca = a, *macb = b; + size_t i; + uint32_t diff = 0; + + for (i = 0; i < nitems(maca->words); i++) + diff |= maca->words[i] ^ macb->words[i]; + + return (diff == 0); +} + +static int +vxnt_input_verify(chacha_ctx *chacha_ctx, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE], uint64_t nonce, + struct mbuf *m, int vxnthlen, const struct vxnt_header_data *vxntd) +{ + poly1305_state poly1305_ctx; + uint8_t block0[CHACHA20POLY1305_KEY_SIZE]; + struct noise_mac mac; + uint64_t datalen; + uint64_t lens[2]; + + chacha_keysetup(chacha_ctx, key, CHACHA20POLY1305_KEY_SIZE * 8); + + /* + * AEAD-ChaCha20-Poly1305 uses the IETF construction with the 96bit + * nonce and 32 bit counter starting from 0. Noise uses this AEAD + * with the counter off the wire as the nonce. That nonce is padded + * with zeros on the left, so the layout of the chacha state is + * predictable. Set it up directly, rather than stuff bits just so + * they can be unstuffed straight away. 
+ */ + chacha_ctx->input[12] = 0; /* counter starts at 0 */ + chacha_ctx->input[13] = 0; /* nonce padding is 0 */ + chacha_ctx->input[14] = nonce; /* the rest of the "nonce" */ + chacha_ctx->input[15] = nonce >> 32; + + /* AEAD-ChaCha20-Poly1305 uses the first block for the poly key */ + bzero(block0, sizeof(block0)); + chacha_encrypt_bytes(chacha_ctx, block0, block0, sizeof(block0)); + poly1305_init(&poly1305_ctx, block0); + explicit_bzero(block0, sizeof(block0)); + + /* there's no aad */ + + /* feed the data in */ + datalen = m->m_len - vxnthlen; + poly1305_update(&poly1305_ctx, mtod(m, uint8_t *) + vxnthlen, datalen); + while ((m = m->m_next) != NULL) { + datalen += m->m_len; + poly1305_update(&poly1305_ctx, mtod(m, uint8_t *), m->m_len); + } + + /* block0 is full of zeros now, so we can use it for padding */ + poly1305_update(&poly1305_ctx, block0, (0x10 - datalen) & 0xf); + + lens[0] = htole64(0); /* still no aad */ + htolem64(&lens[1], datalen); + poly1305_update(&poly1305_ctx, (void *)lens, sizeof(lens)); + + poly1305_finish(&poly1305_ctx, (uint8_t *)&mac); + + if (!vxnt_mac_eq(&vxntd->vxntd_mac, &mac)) { + explicit_bzero(chacha_ctx, sizeof(*chacha_ctx)); + return (-1); + } + + return (0); +} + +static struct mbuf * +vxnt_input(void *arg, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, + void *uhp, int hlen) +{ + struct vxntep *vxntep = arg; + struct udphdr *uh; + struct vxnt_header *vxnt; + struct vxnt_header_data *vxntd; + int vxnthlen = hlen + sizeof(*vxnt); + + struct vxnt_softc *sct; + uint64_t nonce; + struct noise_remote *remote; + struct noise_keypair *kp; + int ret; + chacha_ctx chacha_ctx; + + union vxnt_addr faddr, laddr; + in_port_t fport, lport; + + if (m->m_pkthdr.len < vxnthlen) + goto drop; + + if (ip != NULL) { + memset(&faddr, 0, sizeof(faddr)); + memset(&laddr, 0, sizeof(laddr)); + laddr.in4 = ip->ip_dst; + faddr.in4 = ip->ip_src; + } else { + laddr.in6 = ip6->ip6_dst; + faddr.in6 = ip6->ip6_src; + } + + uh = uhp; + lport = 
uh->uh_dport; + fport = uh->uh_sport; + + if (m->m_len < vxnthlen) { + m = m_pullup(m, vxnthlen); + if (m == NULL) + return (NULL); + } + + vxnt = (struct vxnt_header *)(mtod(m, caddr_t) + hlen); + switch (vxnt->vxnt_type) { + case htonl(VXNT_DATA): + break; + case htonl(VXNT_INIT): + vxnt_handshake_input(vxntep, m, hlen, + sizeof(struct vxnt_header_init)); + return (NULL); + case htonl(VXNT_RESPONSE): + vxnt_handshake_input(vxntep, m, hlen, + sizeof(struct vxnt_header_response)); + return (NULL); + case htonl(VXNT_COOKIE): + vxnt_handshake_input(vxntep, m, hlen, + sizeof(struct vxnt_header_cookie)); + return (NULL); + default: + goto drop; + } + + vxnthlen = hlen + sizeof(*vxntd); + if (m->m_len < vxnthlen) { + m = m_pullup(m, vxnthlen); + if (m == NULL) { + /* counters_inc(vxnt_short_data_header); */ + return (NULL); + } + } + + vxntd = (struct vxnt_header_data *)(mtod(m, caddr_t) + hlen); + + smr_read_enter(); + sct = vxnt_find(vxntep, vxntd->vxntd_rindex); + if (sct == NULL) { + /* counters_inc(vxnt_invalid_index); */ + goto rele_drop; + } + + if (sct->sct_fport != 0 && sct->sct_fport != fport) { + /* sc->sc_wrong_fport++ */ + goto rele_drop; + } + /* XXX */ +#if 0 + if (!ISSET(sct->sct_flags, VNXT_IFT_FADDR_SET) && + !vxnt_addr_eq(&sct->sct_faddr, &faddr)) { + /* sc->sc_wrong_faddr++ */ + goto rele_drop; + } +#endif + + nonce = (uint64_t)bemtoh32(&vxntd->vxntd_seq_hi) << 32 | + (uint64_t)bemtoh32(&vxntd->vxntd_seq_lo); + + kp = noise_decrypt_begin(&sct->sct_remote, vxntd->vxntd_rindex, nonce); + if (kp == NULL) + goto rele_drop; + + if (vxnt_input_verify(&chacha_ctx, kp->kp_recv, nonce, + m, vxnthlen, vxntd) != 0) { + noise_decrypt_rollback(remote); + goto rele_drop; + } + + ret = noise_decrypt_commit(remote, kp, vxntd->vxntd_rindex, nonce); + switch (ret) { + case 0: + break; + case EINVAL: + explicit_bzero(&chacha_ctx, sizeof(chacha_ctx)); + goto rele_drop; + + case ECONNRESET: + /* XXX handshake complete */ + break; + case ESTALE: + /* XXX want init */ + 
break; + default: + panic("%s: noise_remote_commit returned unexpected %d", + sct->sct_if.if_xname, ret); + /* NOTREACHED */ + } + + /* pullup before we trim to try and keep some space for headers */ + m = m_pullup(m, m->m_pkthdr.len); + if (m == NULL) { + /* sc->sc_data_pullup++ */ + explicit_bzero(&chacha_ctx, sizeof(chacha_ctx)); + goto rele; + } + + m_adj(m, vxnthlen); + chacha_encrypt_bytes(&chacha_ctx, mtod(m, void *), mtod(m, void *), + m->m_pkthdr.len); + + vxn_input(sct, m); +rele: + smr_read_leave(); + return (NULL); + +rele_drop: + smr_read_leave(); +drop: + m_freem(m); + return (NULL); +} + +static void +vxn_input(struct vxnt_softc *sct, struct mbuf *m) +{ + struct vxn_header *vxnh; + struct vxn_key key; + struct vxn_softc *sc; + int hlen = sizeof(*vxnh); + int aoff; + + if (m->m_len < sizeof(*vxnh)) { + /* sct->sct_data_header_short++ */ + goto drop; + } + if ((vxnh->vxn_ver_prio & VXNT_VER_MASK) != VXNT_VER_0) { + /* sc->sc_data_version_wrong++ */ + goto drop; + } + if (ISSET(vxnh->vxn_opts, VXNT_F_OAM)) { + /* sc->sc_data_oam++ */ + goto drop; + } + + hlen += (vxnh->vxn_opts & VXNT_OPTLEN_MASK); + if (ISSET(vxnh->vxn_ver_prio, VXNT_PAD)) + hlen += 2; + + if (m->m_len < hlen) { + /* sc->sc_opt_short++ */ + goto drop; + } + + key.vk_proto = 0; + aoff = 0; + + switch (vxnh->vxn_proto) { + case htons(ETHERTYPE_TRANSETHER): + key.vk_proto = vxnh->vxn_proto; + aoff = sizeof(struct ether_header); + break; + case htons(ETHERTYPE_IP): + m->m_pkthdr.ph_family = AF_INET; + break; +#ifdef INET6 + case htons(ETHERTYPE_IPV6): + m->m_pkthdr.ph_family = AF_INET6; + break; +#endif +#ifdef MPLS + case htons(ETHERTYPE_MPLS): + m->m_pkthdr.ph_family = AF_MPLS; + break; +#endif + default: + /* sc->sc_noproto++ */ + goto drop; + } + + key.vk_vnetid = vxnh->vxn_vni; + + mtx_enter(&sct->sct_if_mtx); + sc = (struct vxn_softc *)RBT_FIND(vxn_keys, &sct->sct_if_keys, &key); + if (sc != NULL) + refcnt_take(&sc->sc_refs); + mtx_leave(&sct->sct_if_mtx); + + if (sc == NULL) { + /* 
sc->sc_noif++ */ + goto drop; + } + + switch (sc->sc_rxhprio) { + case IF_HDRPRIO_PACKET: + break; + case IF_HDRPRIO_OUTER: + m->m_pkthdr.pf.prio = vxnh->vxn_ver_prio & VXNT_PRIO_MASK; + break; + default: + m->m_pkthdr.pf.prio = sc->sc_rxhprio; + break; + } + + if (!ALIGNED_POINTER(mtod(m, caddr_t) + aoff, uint32_t)) { + uintptr_t diff = mtod(m, uintptr_t) + aoff; + diff %= sizeof(uint32_t); + + /* sc->sc_data_unaligned++ */ + + /* XXX this fiddles with the mbuf directly */ + m->m_data = memmove(m->m_data - diff, m->m_data, m->m_len); + } + + if_vinput(sc->sc_ifp, m); + vxn_rele(sc); + return; + +drop: + m_freem(m); +} + +static int +vxnt_ioctl(struct ifnet *ifp0, u_long cmd, caddr_t data) +{ + struct vxnt_softc *sct = ifp0->if_softc; + struct ifreq *ifr = (struct ifreq *)data; + int error = 0; + + switch (cmd) { + case SIOCSIFADDR: + return (ENXIO); + case SIOCSIFFLAGS: + if (ISSET(ifp0->if_flags, IFF_UP)) { + if (!ISSET(ifp0->if_flags, IFF_RUNNING)) + error = vxnt_up(sct); + else + error = 0; + } else { + if (ISSET(ifp0->if_flags, IFF_RUNNING)) + error = vxnt_down(sct); + } + break; + + case SIOCSIFMTU: + error = EOPNOTSUPP; + break; + + case SIOCSLIFPHYRTABLE: + error = vxnt_set_rdomain(sct, ifr); + break; + case SIOCGLIFPHYRTABLE: + error = vxnt_get_rdomain(sct, ifr); + break; + + case SIOCSLIFPHYADDR: + error = vxnt_set_tunnel(sct, (const struct if_laddrreq *)data); + break; + case SIOCGLIFPHYADDR: + error = vxnt_get_tunnel(sct, (struct if_laddrreq *)data); + break; + case SIOCDIFPHYADDR: + error = vxnt_del_tunnel(sct); + break; + + case SIOCSLIFPHYDF: + /* commit */ + sct->sct_df = ifr->ifr_df ? htons(IP_DF) : htons(0); + break; + case SIOCGLIFPHYDF: + ifr->ifr_df = sct->sct_df ? 
1 : 0; + break; + + case SIOCSLIFPHYTTL: + if (ifr->ifr_ttl < 1 || ifr->ifr_ttl > 255) { + error = EINVAL; + break; + } + + /* commit */ + sct->sct_ttl = (uint8_t)ifr->ifr_ttl; + break; + case SIOCGLIFPHYTTL: + ifr->ifr_ttl = (int)sct->sct_ttl; + break; + + case SIOCADDMULTI: + case SIOCDELMULTI: + break; + + default: + error = ENOTTY; + break; + } + + return (error); +} + +static int +vxn_ioctl(struct ifnet *ifp, u_long cmd, caddr_t data) +{ + struct vxn_softc *sc = ifp->if_softc; + struct ifreq *ifr = (struct ifreq *)data; + int error = 0; + + switch (cmd) { + case SIOCSIFADDR: + error = 0; + break; + case SIOCSIFFLAGS: + if (ISSET(ifp->if_flags, IFF_UP)) { + if (!ISSET(ifp->if_flags, IFF_RUNNING)) + error = vxn_up(sc); + else + error = 0; + } else { + if (ISSET(ifp->if_flags, IFF_RUNNING)) + error = vxn_down(sc); + } + break; + + case SIOCSIFPARENT: + error = vxn_set_parent(sc, (const struct if_parent *)data); + break; + case SIOCGIFPARENT: + error = vxn_get_parent(sc, (struct if_parent *)data); + break; + case SIOCDIFPARENT: + error = vxn_del_parent(sc); + break; + + case SIOCSVNETID: + error = vxn_set_vnetid(sc, ifr); + break; + case SIOCGVNETID: + error = vxn_get_vnetid(sc, ifr); + break; + + case SIOCSTXHPRIO: + error = if_txhprio_l3_check(ifr->ifr_hdrprio); + if (error != 0) + break; + + sc->sc_txhprio = ifr->ifr_hdrprio; + break; + case SIOCGTXHPRIO: + ifr->ifr_hdrprio = sc->sc_txhprio; + break; + + case SIOCSRXHPRIO: + error = if_rxhprio_l3_check(ifr->ifr_hdrprio); + if (error != 0) + break; + + sc->sc_rxhprio = ifr->ifr_hdrprio; + break; + case SIOCGRXHPRIO: + ifr->ifr_hdrprio = sc->sc_rxhprio; + break; + + case SIOCADDMULTI: + case SIOCDELMULTI: + break; + + default: + error = ENOTTY; + break; + } + + return (error); +} + +static int +evxn_ioctl(struct ifnet *ifp, u_long cmd, caddr_t data) +{ + struct vxn_softc *sc = ifp->if_softc; + struct ifreq *ifr = (struct ifreq *)data; + int error = 0; + + switch (cmd) { + case SIOCSIFADDR: + error = 0; + break; + 
case SIOCSIFFLAGS: + if (ISSET(ifp->if_flags, IFF_UP)) { + if (!ISSET(ifp->if_flags, IFF_RUNNING)) + error = vxn_up(sc); + else + error = 0; + } else { + if (ISSET(ifp->if_flags, IFF_RUNNING)) + error = vxn_down(sc); + } + break; + + case SIOCSIFPARENT: + error = vxn_set_parent(sc, (const struct if_parent *)data); + break; + case SIOCGIFPARENT: + error = vxn_get_parent(sc, (struct if_parent *)data); + break; + case SIOCDIFPARENT: + error = vxn_del_parent(sc); + break; + + case SIOCSVNETID: + error = vxn_set_vnetid(sc, ifr); + break; + case SIOCGVNETID: + error = vxn_get_vnetid(sc, ifr); + break; + + case SIOCSTXHPRIO: + error = if_txhprio_l2_check(ifr->ifr_hdrprio); + if (error != 0) + break; + + sc->sc_txhprio = ifr->ifr_hdrprio; + break; + case SIOCGTXHPRIO: + ifr->ifr_hdrprio = sc->sc_txhprio; + break; + + case SIOCSRXHPRIO: + error = if_rxhprio_l2_check(ifr->ifr_hdrprio); + if (error != 0) + break; + + sc->sc_rxhprio = ifr->ifr_hdrprio; + break; + case SIOCGRXHPRIO: + ifr->ifr_hdrprio = sc->sc_rxhprio; + break; + + case SIOCADDMULTI: + case SIOCDELMULTI: + break; + + default: + error = ether_ioctl(ifp, (struct arpcom *)ifp, cmd, data); + break; + } + + if (error == ENETRESET) + error = 0; + + return (error); +} + +static int +vxn_up(struct vxn_softc *sc) +{ + struct ifnet *ifp = sc->sc_ifp; + struct ifnet *ifp0; + struct vxnt_softc *sct; + struct vxn_key *ovk; + int error = 0; + + ifp0 = if_get(sc->sc_if_index0); + if (ifp == NULL) + return (ENXIO); + + /* check vxn/evxn will work on top of the parent */ + if (ifp0->if_type != VXNT_IFTYPE) { + error = EPROTONOSUPPORT; + goto put; + } + + sct = ifp0->if_softc; + + /* commit the sc */ + mtx_enter(&sct->sct_if_mtx); + ovk = RBT_INSERT(vxn_keys, &sct->sct_if_keys, &sc->sc_key); + mtx_leave(&sct->sct_if_mtx); + + if (ovk != NULL) { + error = EADDRINUSE; + goto put; + } + + /* Register callback for physical link state changes */ + if_linkstatehook_add(ifp0, &sc->sc_lhook); + + /* Register callback if parent wants to 
unregister */ + if_detachhook_add(ifp0, &sc->sc_dhook); + + /* we're running now */ + SET(ifp->if_flags, IFF_RUNNING); + vxn_link_hook(sc); + + if_put(ifp0); +put: + return (error); +} + +static int +vxn_down(struct vxn_softc *sc) +{ + struct ifnet *ifp = sc->sc_ifp; + struct ifnet *ifp0; + struct vxnt_softc *sct; + + CLR(ifp->if_flags, IFF_RUNNING); + + ifp0 = if_get(sc->sc_if_index0); + if (ifp0 == NULL) { + sc->sc_if_index0 = 0; + return (0); + } + + sct = ifp0->if_softc; + + if_detachhook_del(ifp0, &sc->sc_dhook); + if_detachhook_del(ifp0, &sc->sc_lhook); + + mtx_enter(&sct->sct_if_mtx); + RBT_REMOVE(vxn_keys, &sct->sct_if_keys, &sc->sc_key); + mtx_leave(&sct->sct_if_mtx); + + if_put(ifp0); + + return (0); +} + +static int +vxnt_output(struct ifnet *ifp0, struct mbuf *m, struct sockaddr *dst, + struct rtentry *rt) +{ + m_freem(m); + return (EOPNOTSUPP); +} + +static int +vxn_output(struct ifnet *ifp, struct mbuf *m, struct sockaddr *dst, + struct rtentry *rt) +{ + int error; + + if (!ISSET(ifp->if_flags, IFF_RUNNING)) { + error = ENETDOWN; + goto drop; + } + + switch (dst->sa_family) { + case AF_INET: +#ifdef INET6 + case AF_INET6: +#endif +#ifdef MPLS + case AF_MPLS: +#endif + break; + default: + error = EAFNOSUPPORT; + goto drop; + } + + m->m_pkthdr.ph_family = dst->sa_family; + + error = if_enqueue(ifp, m); + return (error); + +drop: + m_freem(m); + return (error); +} + +static int +vxnt_enqueue(struct ifnet *ifp0, struct mbuf *m) +{ + int error = 0; + struct m_tag *mtag = NULL; + unsigned int if_index; + struct ifqueue *ifq; + + /* Try to limit infinite recursion through misconfiguration. 
*/ + while ((mtag = m_tag_find(m, PACKET_TAG_GRE, mtag)) != NULL) { + if_index = *(unsigned int *)(mtag + 1); + if (ifp0->if_index == if_index) { + error = EIO; + goto drop; + } + } + + mtag = m_tag_get(PACKET_TAG_GRE, sizeof(if_index), M_NOWAIT); + if (mtag == NULL) { + error = ENOBUFS; + goto drop; + } + *(unsigned int *)(mtag + 1) = ifp0->if_index; + m_tag_prepend(m, mtag); + + /* push the packet onto the ifq separately to running the ifq */ + ifq = &ifp0->if_snd; + error = ifq_enqueue(ifq, m); + if (error) + return (error); + + /* always defer running the ifq to the taskq */ + task_add(ifq->ifq_softnet, &ifq->ifq_bundle); + return (0); + +drop: + m_freem(m); + return (error); +} + +static int +vxn_enqueue(struct ifnet *ifp, struct mbuf *m) +{ + struct ifnet *ifp0; + struct vxn_softc *sc; + int error = 0; + + if (!ifq_is_priq(&ifp->if_snd)) + return (if_enqueue_ifq(ifp, m)); + + sc = ifp->if_softc; + ifp0 = if_get(sc->sc_if_index0); + + if (ifp0 == NULL || !ISSET(ifp0->if_flags, IFF_RUNNING)) { + m_freem(m); + error = ENETDOWN; + } else { + counters_pkt(ifp->if_counters, + ifc_opackets, ifc_obytes, m->m_pkthdr.len); + vxn_start_transmit(sc, ifp0, m); + } + + if_put(ifp0); + + return (error); +} + +static void +vxnt_start(struct ifqueue *ifq) +{ + struct ifnet *ifp0 = ifq->ifq_if; + struct vxnt_softc *sct = ifp0->if_softc; + struct noise_remote *r = &sct->sct_remote; + struct noise_keypair *kp; + struct mbuf *m; + uint32_t index; + uint64_t nonce; + struct mbuf_list ml = MBUF_LIST_INITIALIZER(); + int stale; + + /* + * XXX this is a lot of noise_remote_encrypt. updates to the + * send counter are serialised by the ifq. 
+ */ + + rw_enter_read(&r->r_keypair_lock); + kp = r->r_current; + if (kp == NULL || !kp->kp_valid || + noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || + kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES || + (nonce = kp->kp_ctr.c_send) >= REJECT_AFTER_MESSAGES) { + rw_exit_read(&r->r_keypair_lock); + ifq_purge(ifq); + return; + } + + index = kp->kp_remote_index; + + while ((m = ifq_dequeue(ifq)) != NULL) { + m = vxnt_encrypt(sct, kp, m, nonce); + if (m == NULL) + continue; + + ml_enqueue(&ml, m); + ++nonce; + } + + stale = (nonce >= REKEY_AFTER_MESSAGES) || + (kp->kp_is_initiator && + noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME, 0)); + + kp->kp_ctr.c_send = nonce; + rw_exit_read(&r->r_keypair_lock); + + if (stale) + wg_timers_event_want_initiation(&peer->p_timers); + + vxnt_send(sct, &ml); +} + +static uint64_t +vxnt_send_ipv4(struct mbuf *m) +{ + int rv; + + rv = ip_output(m, NULL, NULL, IP_RAWOUTPUT, NULL, NULL, 0); + + return (rv != 0); +} + +#ifdef INET6 +static uint64_t +vxnt_send_ipv6(struct mbuf *m) +{ + int rv; + + rv = ip6_output(m, NULL, NULL, 0, NULL, NULL); + + return (rv != 0); +} +#endif /* INET6 */ + +static void +vxnt_send(struct vxnt_softc *sct, struct mbuf_list *ml) +{ + struct mbuf *(*encap)(struct vxnt_softc *, struct mbuf *, + const union vxnt_addr *, const union vxnt_addr *, uint8_t); + uint64_t (*send)(struct mbuf *); + struct mbuf *m; + struct udphdr *uh; + uint64_t oerrors = 0; + + switch (sct->sct_af) { + case AF_INET: + encap = vxnt_encap_ipv4; + send = vxnt_send_ipv4; + break; +#ifdef INET6 + case AF_INET6: + encap = vxnt_encap_ipv6; + send = vxnt_send_ipv6; + break; +#endif + default: + unhandled_af(sct->sct_af); + /* NOTREACHED */ + } + + NET_LOCK(); + while ((m = ml_dequeue(ml)) != NULL) { + m = m_prepend(m, sizeof(*uh), M_DONTWAIT); + if (m == NULL) { + oerrors++; + continue; + } + + uh = mtod(m, struct udphdr *); + uh->uh_sport = sct->sct_sport; + uh->uh_dport = sct->sct_dport; + htobem16(&uh->uh_ulen, 
m->m_pkthdr.len); + uh->uh_sum = 0; + + m = (*encap)(sct, m, &sct->sct_saddr, &sct->sct_daddr, 0); + if (m == NULL) { + /* encap failed and consumed the mbuf; skip to the next one */ + oerrors++; + continue; + } + + CLR(m->m_flags, M_BCAST|M_MCAST); + m->m_pkthdr.ph_rtableid = sct->sct_rdomain; +#if NPF > 0 + pf_pkt_addr_changed(m); +#endif + + oerrors += (*send)(m); + } + NET_UNLOCK(); + + counters_add(sct->sct_if.if_counters, ifc_oerrors, oerrors); +} + +static struct mbuf * +vxnt_encrypt(struct vxnt_softc *sc, struct noise_keypair *kp, struct mbuf *m, + uint64_t nonce) +{ + struct vxnt_header_data *vxntd; + int datalen = m->m_pkthdr.len; + void *data; + + chacha_ctx chacha_ctx; + poly1305_state poly1305_ctx; + uint8_t block0[CHACHA20POLY1305_KEY_SIZE]; + uint64_t lens[2]; + + m = m_prepend(m, sizeof(*vxntd), M_DONTWAIT); + if (m == NULL) + return (NULL); + + /* chacha wants contig memory */ + m = m_pullup(m, m->m_pkthdr.len); + if (m == NULL) + return (NULL); + + vxntd = mtod(m, struct vxnt_header_data *); + vxntd->vxntd_type = htonl(VXNT_DATA); + vxntd->vxntd_rindex = kp->kp_remote_index; + htobem32(&vxntd->vxntd_seq_hi, nonce >> 32); + htobem32(&vxntd->vxntd_seq_lo, nonce); + + data = (vxntd + 1); + + chacha_keysetup(&chacha_ctx, kp->kp_send, + CHACHA20POLY1305_KEY_SIZE * 8); + + /* + * AEAD-ChaCha20-Poly1305 uses the IETF construction with the 96bit + * nonce and 32 bit counter starting from 0. Noise uses this AEAD + * with the counter off the wire as the nonce. That nonce is padded + * with zeros on the left, so the layout of the chacha state is + * predictable. Set it up directly, rather than stuff bits just so + * they can be unstuffed straight away. 
+ */ + chacha_ctx.input[12] = 0; /* counter starts at 0 */ + chacha_ctx.input[13] = 0; /* nonce padding is 0 */ + chacha_ctx.input[14] = nonce; /* the rest of the "nonce" */ + chacha_ctx.input[15] = nonce >> 32; + + /* AEAD-ChaCha20-Poly1305 uses the first block for the poly key */ + bzero(block0, sizeof(block0)); + chacha_encrypt_bytes(&chacha_ctx, block0, block0, sizeof(block0)); + poly1305_init(&poly1305_ctx, block0); + explicit_bzero(block0, sizeof(block0)); + + /* there's no aad */ + + /* feed the data in */ + chacha_encrypt_bytes(&chacha_ctx, data, data, datalen); + explicit_bzero(&chacha_ctx, sizeof(chacha_ctx)); + poly1305_update(&poly1305_ctx, data, datalen); + + /* block0 is full of zeros now, so we can use it for padding */ + poly1305_update(&poly1305_ctx, block0, (0x10 - datalen) & 0xf); + + lens[0] = htole64(0); /* still no aad */ + htolem64(&lens[1], datalen); + poly1305_update(&poly1305_ctx, (void *)lens, sizeof(lens)); + poly1305_finish(&poly1305_ctx, vxntd->vxntd_mac); + explicit_bzero(&poly1305_ctx, sizeof(poly1305_ctx)); + + return (m); +} + +static void +vxn_start(struct ifqueue *ifq) +{ + struct ifnet *ifp = ifq->ifq_if; + struct vxn_softc *sc = ifp->if_softc; + struct ifnet *ifp0; + struct mbuf *m; + + ifp0 = if_get(sc->sc_if_index0); + if (!vxn_if_running(ifp0)) { + ifq_purge(ifq); + goto leave; + } + + while ((m = ifq_dequeue(ifq)) != NULL) + vxn_start_transmit(sc, ifp0, m); + +leave: + if_put(ifp0); +} + +static void +vxn_start_transmit(struct vxn_softc *sc, struct ifnet *ifp0, struct mbuf *m) +{ +#if NBPFILTER > 0 + caddr_t if_bpf = READ_ONCE(ifp0->if_bpf); + + if (if_bpf) + (*ifp0->if_bpf_mtap)(if_bpf, m, BPF_DIRECTION_OUT); +#endif /* NBPFILTER > 0 */ + + (*sc->sc_transmit)(sc, ifp0, m); +} + +static void +vxn_transmit(struct vxn_softc *sc, struct ifnet *ifp0, struct mbuf *m) +{ + uint16_t proto; + uint8_t flags, tos; + + switch (m->m_pkthdr.ph_family) { + case AF_INET: { + struct ip *ip; + + m = m_pullup(m, sizeof(*ip)); + if (m == 
NULL) + return; + + ip = mtod(m, struct ip *); + tos = ip->ip_tos; + + proto = htons(ETHERTYPE_IP); + break; + } +#ifdef INET6 + case AF_INET6: { + struct ip6_hdr *ip6; + + m = m_pullup(m, sizeof(*ip6)); + if (m == NULL) + return; + + ip6 = mtod(m, struct ip6_hdr *); + tos = (bemtoh32(&ip6->ip6_flow) & 0x0ff00000) >> 20; + + proto = htons(ETHERTYPE_IPV6); + break; + } +#endif +#ifdef MPLS + case AF_MPLS: { + uint32_t shim; + + m = m_pullup(m, sizeof(shim)); + if (m == NULL) + return; + + shim = bemtoh32(mtod(m, uint32_t *)) & MPLS_EXP_MASK; + tos = IFQ_PRIO2TOS(shim >> MPLS_EXP_OFFSET); + + proto = htons(ETHERTYPE_MPLS); + break; + } +#endif + default: + unhandled_af(m->m_pkthdr.ph_family); + } + + switch (sc->sc_txhprio) { + case IF_HDRPRIO_PAYLOAD: + flags = IFQ_TOS2PRIO(tos); + break; + case IF_HDRPRIO_PACKET: + flags = m->m_pkthdr.pf.prio; + break; + default: + flags = sc->sc_txhprio; + break; + } + + vxn_encap(sc, ifp0, m, flags, proto); +} + +static void +evxn_transmit(struct vxn_softc *sc, struct ifnet *ifp0, struct mbuf *m) +{ + uint8_t flags; + uint16_t *pad; + + m = m_prepend(m, sizeof(*pad), M_DONTWAIT); + if (m == NULL) + return; + + switch (sc->sc_txhprio) { + case IF_HDRPRIO_PACKET: + flags = m->m_pkthdr.pf.prio; + break; + default: + flags = sc->sc_txhprio; + break; + } + + flags |= VXNT_PAD; + + vxn_encap(sc, ifp0, m, flags, htons(ETHERTYPE_TRANSETHER)); +} + +static void +vxn_encap(struct vxn_softc *sc, struct ifnet *ifp0, struct mbuf *m, + uint8_t flags, uint16_t proto) +{ + struct vxn_header *vh; + + m = m_prepend(m, sizeof(*vh), M_DONTWAIT); + if (m == NULL) + return; + + vh = mtod(m, struct vxn_header *); + vh->vxn_ver_prio = VXNT_VER_0 | flags; + vh->vxn_opts = 0; + vh->vxn_proto = proto; + vh->vxn_vni = sc->sc_key.vk_vnetid; + + if_enqueue(ifp0, m); +} + +static struct vxntep * +vxntep_get(struct vxnt_softc *sct) +{ + struct vxntep *vxntep; + + TAILQ_FOREACH(vxntep, &vxnteps, vxntep_entry) { + if (sct->sct_af == vxntep->vxntep_af && + 
sct->sct_rdomain == vxntep->vxntep_rdomain && + vxnt_addr_eq(&sct->sct_laddr, &vxntep->vxntep_addr) && + (sct->sct_lport == htons(0) || + sct->sct_lport == vxntep->vxntep_port)) + return (vxntep); + } + + return (NULL); +} + +static int +geneve_tep_add_addr(struct geneve_softc *sc) +{ + struct mbuf m; + struct geneve_tep *gtep; + struct socket *so; + struct sockaddr_in *sin; +#ifdef INET6 + struct sockaddr_in6 *sin6; +#endif + int error; + int s; + + gtep = geneve_tep_get(sc, addr); + if (gtep != NULL) { + struct geneve_peer *op; + + mtx_enter(>ep->gtep_mtx); + op = RBT_INSERT(geneve_peers, >ep->gtep_peers, p); + mtx_leave(>ep->gtep_mtx); + + if (op != NULL) + return (EADDRINUSE); + + return (0); + } + + vxntep = malloc(sizeof(*vxntep), M_DEVBUF, M_NOWAIT|M_ZERO); + if (vxntep == NULL) + return (ENOMEM); + + vxntep->vxntep_af = sc->sc_af; + vxntep->vxntep_rdomain = sc->sc_rdomain; + vxntep->vxntep_addr = sc->sc_laddr; + vxntep->vxntep_port = sc->sc_lport; + + mtx_init(&vxntep->vxntep_mtx, IPL_SOFTNET); + RBT_INIT(vxntep_peers, &vxntep->vxntep_peers); + RBT_INSERT(vxntep_peers, &vxntep->vxntep_peers, p); + + error = socreate(vxntep->vxntep_af, &so, SOCK_DGRAM, IPPROTO_UDP); + if (error != 0) + goto free; + + s = solock(so); + + sotoinpcb(so)->inp_upcall = vxnt_input; + sotoinpcb(so)->inp_upcall_arg = vxntep; + + m_inithdr(&m); + m.m_len = sizeof(vxntep->vxntep_rdomain); + *mtod(&m, unsigned int *) = vxntep->vxntep_rdomain; + error = sosetopt(so, SOL_SOCKET, SO_RTABLE, &m); + if (error != 0) + goto close; + + m_inithdr(&m); + switch (gtep->gtep_af) { + case AF_INET: + sin = mtod(&m, struct sockaddr_in *); + memset(sin, 0, sizeof(*sin)); + sin->sin_len = sizeof(*sin); + sin->sin_family = AF_INET; + sin->sin_addr = addr->in4; + sin->sin_port = gtep->gtep_port; + + m.m_len = sizeof(*sin); + break; + +#ifdef INET6 + case AF_INET6: + sin6 = mtod(&m, struct sockaddr_in6 *); + sin6->sin6_len = sizeof(*sin6); + sin6->sin6_family = AF_INET6; + in6_recoverscope(sin6, 
&addr->in6); + sin6->sin6_port = sc->sc_port; + + m.m_len = sizeof(*sin6); + break; +#endif + default: + unhandled_af(gtep->gtep_af); + } + + error = sobind(so, &m, curproc); + if (error != 0) + goto close; + + sounlock(so, s); + + rw_assert_wrlock(&geneve_lock); + TAILQ_INSERT_TAIL(&geneve_teps, gtep, gtep_entry); + + gtep->gtep_so = so; + + return (0); + +close: + sounlock(so, s); + soclose(so, MSG_DONTWAIT); +free: + free(gtep, M_DEVBUF, sizeof(*gtep)); + return (error); +} + +static void +geneve_tep_del_addr(struct geneve_softc *sc, const union geneve_addr *addr, + struct geneve_peer *p) +{ + struct geneve_tep *gtep; + int empty; + + gtep = geneve_tep_get(sc, addr); + if (gtep == NULL) + panic("unable to find geneve_tep for peer %p (sc %p)", p, sc); + + mtx_enter(>ep->gtep_mtx); + RBT_REMOVE(geneve_peers, >ep->gtep_peers, p); + empty = RBT_EMPTY(geneve_peers, >ep->gtep_peers); + mtx_leave(>ep->gtep_mtx); + + if (!empty) + return; + + rw_assert_wrlock(&geneve_lock); + TAILQ_REMOVE(&geneve_teps, gtep, gtep_entry); + + soclose(gtep->gtep_so, MSG_DONTWAIT); + free(gtep, M_DEVBUF, sizeof(*gtep)); +} + +static int +geneve_tep_up(struct geneve_softc *sc) +{ + struct geneve_peer *up, *mp; + int error; + + up = malloc(sizeof(*up), M_DEVBUF, M_NOWAIT|M_ZERO); + if (up == NULL) + return (ENOMEM); + + up->p_mask = (sc->sc_mode != GENEVE_TMODE_P2P); + up->p_addr = sc->sc_dst; + up->p_header = sc->sc_header; + up->p_if_index = sc->sc_ac.ac_if.if_index; + + error = geneve_tep_add_addr(sc, &sc->sc_src, up); + if (error != 0) + goto freeup; + + sc->sc_ucast_peer = up; + + if (sc->sc_mode != GENEVE_TMODE_LEARNING) + return (0); + + mp = malloc(sizeof(*mp), M_DEVBUF, M_NOWAIT|M_ZERO); + if (mp == NULL) { + error = ENOMEM; + goto delup; + } + + mp->p_mask = 1; + /* addr is masked, leave it as 0s */ + mp->p_header = sc->sc_header; + mp->p_if_index = sc->sc_ac.ac_if.if_index; + + /* destination address is a multicast group we want to join */ + error = geneve_tep_add_addr(sc, 
&sc->sc_dst, up); + if (error != 0) + goto freemp; + + sc->sc_mcast_peer = mp; + + return (0); + +freemp: + free(mp, M_DEVBUF, sizeof(*mp)); +delup: + geneve_tep_del_addr(sc, &sc->sc_src, up); +freeup: + free(up, M_DEVBUF, sizeof(*up)); + return (error); +} + +static void +vxntep_down(struct vxnt_softc *sc) +{ + struct vxnt_peer *p = sc->sc_ucast_peer; + + geneve_tep_del_addr(sc, &sc->sc_src, up); + free(up, M_DEVBUF, sizeof(*up)); +} + +static int +vxnt_up(struct vxnt_softc *sc) +{ + struct ifnet *ifp = &sc->sc_if; + int error; + + KASSERT(!ISSET(ifp->if_flags, IFF_RUNNING)); + NET_ASSERT_LOCKED(); + + if (sc->sc_af == AF_UNSPEC) + return (EDESTADDRREQ); + + NET_UNLOCK(); + + error = rw_enter(&vxnt_lock, RW_WRITE|RW_INTR); + if (error != 0) + goto netlock; + + NET_LOCK(); + if (ISSET(ifp->if_flags, IFF_RUNNING)) { + /* something else beat us */ + rw_exit(&vxnt_lock); + return (0); + } + NET_UNLOCK(); + + error = vxnt_tep_up(sc); + if (error != 0) + goto unlock; + + NET_LOCK(); + SET(ifp->if_flags, IFF_RUNNING); + rw_exit(&vxnt_lock); + + return (0); + +unlock: + rw_exit(&geneve_lock); +netlock: + NET_LOCK(); + + return (error); +} + +static int +vxnt_down(struct vxnt_softc *sc) +{ + struct ifnet *ifp = &sc->sc_if; + struct ifnet *ifp0; + int error; + + KASSERT(ISSET(ifp->if_flags, IFF_RUNNING)); + NET_UNLOCK(); + + error = rw_enter(&vxnt_lock, RW_WRITE|RW_INTR); + if (error != 0) { + NET_LOCK(); + return (error); + } + + NET_LOCK(); + if (!ISSET(ifp->if_flags, IFF_RUNNING)) { + /* something else beat us */ + rw_exit(&vxnt_lock); + return (0); + } + NET_UNLOCK(); + + vxnt_tep_down(sc); + + taskq_del_barrier(ifp->if_snd.ifq_softnet, &sc->sc_send_task); + NET_LOCK(); + CLR(ifp->if_flags, IFF_RUNNING); + rw_exit(&vxnt_lock); + + return (0); +} + +static int +vxnt_set_rdomain(struct vxnt_softc *sc, const struct ifreq *ifr) +{ + struct ifnet *ifp = &sc->sc_ac.ac_if; + + if (ifr->ifr_rdomainid < 0 || + ifr->ifr_rdomainid > RT_TABLEID_MAX) + return (EINVAL); + if 
(!rtable_exists(ifr->ifr_rdomainid)) + return (EADDRNOTAVAIL); + + if (sc->sc_rdomain == ifr->ifr_rdomainid) + return (0); + + if (!ISSET(ifp->if_flags, IFF_RUNNING)) + return (EBUSY); + + /* commit */ + sc->sc_rdomain = ifr->ifr_rdomainid; + + return (0); +} + +static int +vxnt_get_rdomain(struct vxnt_softc *sc, struct ifreq *ifr) +{ + ifr->ifr_rdomainid = sc->sc_rdomain; + + return (0); +} + +static int +vxnt_set_tunnel(struct vxnt_softc *sc, const struct if_laddrreq *req) +{ + struct ifnet *ifp = &sc->sc_if; + struct sockaddr *src = (struct sockaddr *)&req->addr; + struct sockaddr *dst = (struct sockaddr *)&req->dstaddr; + struct sockaddr_in *src4, *dst4; +#ifdef INET6 + struct sockaddr_in6 *src6, *dst6; + int error; +#endif + union vxnt_addr saddr, daddr; + in_port_t sport = htons(0); + in_port_t dport = htons(0); + unsigned int mode = VXNT_MODE_SPOKE; + + if (dst->sa_family == AF_UNSPEC) + mode = VXNT_MODE_HUB; + else if (src->sa_family != dst->sa_family) + return (EINVAL); + + memset(&saddr, 0, sizeof(saddr)); + memset(&daddr, 0, sizeof(daddr)); + + /* validate */ + switch (src->sa_family) { + case AF_INET: + src4 = (struct sockaddr_in *)src; + if (IN_MULTICAST(src4->sin_addr.s_addr)) + return (EINVAL); + + saddr.in4 = src4->sin_addr; + sport = src4->sin_port; + + if (mode == VXNT_MODE_HUB) + break; + + dst4 = (struct sockaddr_in *)dst; + if (in_nullhost(dst4->sin_addr) || + IN_MULTICAST(dst4->sin_addr.s_addr)) + return (EINVAL); + + daddr.in4 = src4->sin_addr; + dport = src4->sin_port; + break; + +#ifdef INET6 + case AF_INET6: + src6 = (struct sockaddr_in6 *)src; + if (IN6_IS_ADDR_MULTICAST(&src6->sin6_addr)) + return (EINVAL); + + error = in6_embedscope(&saddr.in6, src6, NULL); + if (error != 0) + return (error); + sport = src6->sin6_port; + + if (mode == VXNT_MODE_HUB) + break; + + dst6 = (struct sockaddr_in6 *)dst; + if (IN6_IS_ADDR_UNSPECIFIED(&dst6->sin6_addr) || + IN6_IS_ADDR_MULTICAST(&dst6->sin6_addr)) + return (EINVAL); + + if (src6->sin6_scope_id 
!= dst6->sin6_scope_id) + return (EINVAL); + + error = in6_embedscope(&daddr.in6, dst6, NULL); + if (error != 0) + return (error); + dport = dst6->sin6_port; + break; +#endif + default: + return (EAFNOSUPPORT); + } + + if (mode == VXNT_MODE_HUB) { + if (sport == htons(0)) + return (EADDRNOTAVAIL); + } else { /* VXNT_MODE_SPOKE */ + if (dport == htons(0)) + return (EDESTADDRREQ); + } + + if (sc->sc_mode == mode && + vxnt_addr_eq(&sc->sc_saddr, &saddr) == 0 && + vxnt_addr_eq(&sc->sc_daddr, &daddr) == 0 && + sc->sc_sport == sport && sc->sc_dport == dport) + return (0); + + if (ISSET(ifp->if_flags, IFF_RUNNING)) + return (EBUSY); + + /* commit */ + sc->sc_af = src->sa_family; + sc->sc_saddr = saddr; + sc->sc_daddr = daddr; + sc->sc_sport = sport; + sc->sc_dport = dport; + + return (0); +} + +static int +vxnt_get_tunnel(struct vxnt_softc *sc, struct if_laddrreq *req) +{ + struct sockaddr *dstaddr = (struct sockaddr *)&req->dstaddr; + struct sockaddr_in *sin; +#ifdef INET6 + struct sockaddr_in6 *sin6; +#endif + + if (sc->sc_af == AF_UNSPEC) + return (EADDRNOTAVAIL); + + memset(&req->addr, 0, sizeof(req->addr)); + memset(&req->dstaddr, 0, sizeof(req->dstaddr)); + + /* default to endpoint */ + dstaddr->sa_len = 2; + dstaddr->sa_family = AF_UNSPEC; + + switch (sc->sc_af) { + case AF_INET: + sin = (struct sockaddr_in *)&req->addr; + sin->sin_len = sizeof(*sin); + sin->sin_family = AF_INET; + sin->sin_addr = sc->sc_saddr.in4; + sin->sin_port = sc->sc_sport; + + sin = (struct sockaddr_in *)&req->dstaddr; + sin->sin_len = sizeof(*sin); + sin->sin_family = AF_INET; + sin->sin_addr = sc->sc_daddr.in4; + sin->sin_port = sc->sc_dport; + break; + +#ifdef INET6 + case AF_INET6: + sin6 = (struct sockaddr_in6 *)&req->addr; + sin6->sin6_len = sizeof(*sin6); + sin6->sin6_family = AF_INET6; + in6_recoverscope(sin6, &sc->sc_saddr.in6); + sin6->sin6_port = sc->sc_sport; + + sin6 = (struct sockaddr_in6 *)&req->dstaddr; + sin6->sin6_len = sizeof(*sin6); + sin6->sin6_family = AF_INET6; + 
in6_recoverscope(sin6, &sc->sc_daddr.in6); + sin6->sin6_port = sc->sc_dport; + break; +#endif + default: + unhandled_af(sc->sc_af); + } + + return (0); +} + +static int +vxnt_del_tunnel(struct vxnt_softc *sc) +{ + struct ifnet *ifp = &sc->sc_ac.ac_if; + + if (sc->sc_af == AF_UNSPEC) + return (0); + + if (ISSET(ifp->if_flags, IFF_RUNNING)) + return (EBUSY); + + /* commit */ + sc->sc_af = AF_UNSPEC; + memset(&sc->sc_saddr, 0, sizeof(sc->sc_saddr)); + memset(&sc->sc_daddr, 0, sizeof(sc->sc_daddr)); + sc->sc_sport = htons(0); + sc->sc_dport = htons(0); + + return (0); +} + +static int +vxn_set_vnetid(struct vxn_softc *sc, const struct ifreq *ifr) +{ + struct ifnet *ifp = sc->sc_if; + uint32_t vnetid; + + if (ifr->ifr_vnetid < 0x00000000 || + ifr->ifr_vnetid > 0xffffffff) + return (EINVAL); + + vnetid = htonl(ifr->ifr_vnetid); + if (sc->sc_key.vk_vnetid == vnetid) + return (0); + + if (ISSET(ifp->if_flags, IFF_RUNNING)) + return (EBUSY); + + /* commit */ + sc->sc_key.vk_vnetid = vnetid; + + return (0); +} + +static int +vxn_get_vnetid(struct vxn_softc *sc, struct ifreq *ifr) +{ + ifr->ifr_vnetid = ntohl(sc->sc_key.vk_vnetid); + + return (0); +} + +static int +vxn_set_parent(struct vxn_softc *sc, const struct if_parent *p) +{ + struct ifnet *ifp = sc->sc_ifp; + struct ifnet *ifp0; + int error = 0; + + ifp0 = if_unit(p->ifp_parent); + if (ifp0 == NULL) + return (ENXIO); + + if (ifp0->if_type != VXNT_IFTYPE) { + error = ENXIO; + goto put; + } + + if (sc->sc_if_index0 == ifp0->if_index) + goto put; + + if (ISSET(ifp->if_flags, IFF_RUNNING)) { + error = EBUSY; + goto put; + } + + /* commit */ + sc->sc_if_index0 = ifp0->if_index; + +put: + if_put(ifp0); + return (error); +} + +static int +vxn_get_parent(struct vxn_softc *sc, struct if_parent *p) +{ + struct ifnet *ifp0; + int error = 0; + + ifp0 = if_get(sc->sc_if_index0); + if (ifp0 == NULL) + error = EADDRNOTAVAIL; + else + strlcpy(p->ifp_parent, ifp0->if_xname, sizeof(p->ifp_parent)); + if_put(ifp0); + + return (error); +} 
+ +static int +vxn_del_parent(struct vxn_softc *sc) +{ + struct ifnet *ifp = sc->sc_ifp; + + if (sc->sc_if_index0 == 0) + return (0); + + if (ISSET(ifp->if_flags, IFF_RUNNING)) + return (EBUSY); + + /* commit */ + sc->sc_if_index0 = 0; + + return (0); +} + +void +vxn_detach_hook(void *arg) +{ + struct vxn_softc *sc = arg; + struct ifnet *ifp = sc->sc_ifp; + + if (ISSET(ifp->if_flags, IFF_RUNNING)) { + vxn_down(sc); + CLR(ifp->if_flags, IFF_UP); + } + + sc->sc_if_index0 = 0; +} + +static void +vxn_link_hook(void *arg) +{ + struct vxn_softc *sc = arg; + struct ifnet *ifp = sc->sc_if; + struct ifnet *ifp0; + int link_state = LINK_STATE_UNKNOWN; + + ifp0 = if_get(sc->sc_if_index0); + if (ifp0 != NULL) + link_state = ifp0->if_link_state; + if_put(ifp0); + + if (ifp->if_link_state != link_state) { + ifp->if_link_state = link_state; + if_link_state_change(ifp); + } +} + +static inline int +vxnt_index_cmp(const struct vxnt_index *a, const struct vxnt_index *b) +{ + if (a->vi_index > b->vi_index) + return (1); + if (a->vi_index < b->vi_index) + return (-1); + + return (0); +} + +RBT_GENERATE(vxnt_indexes, vxnt_index, vi_entry, vxnt_index_cmp); + +static inline int +vxn_key_cmp(const struct vxn_key *a, const struct vxn_key *b) +{ + if (a->vk_vnetid > b->vk_vnetid) + return (1); + if (a->vk_vnetid < b->vk_vnetid) + return (-1); + + if (a->vk_proto > b->vk_proto) + return (1); + if (a->vk_proto < b->vk_proto) + return (-1); + + return (0); +} + +RBT_GENERATE(vxn_keys, vxn_key, vk_entry, vxn_key_cmp); Index: wg_noise.c =================================================================== RCS file: /cvs/src/sys/net/wg_noise.c,v retrieving revision 1.5 diff -u -p -r1.5 wg_noise.c --- wg_noise.c 21 Mar 2021 18:13:59 -0000 1.5 +++ wg_noise.c 28 Apr 2021 01:00:19 -0000 @@ -78,7 +78,6 @@ static void noise_msg_ephemeral( const uint8_t src[NOISE_PUBLIC_KEY_LEN]); static void noise_tai64n_now(uint8_t [NOISE_TIMESTAMP_LEN]); -static int noise_timer_expired(struct timespec *, time_t, long); 
/* Set/Get noise parameters */ void @@ -555,12 +554,10 @@ noise_remote_ready(struct noise_remote * return ret; } -int -noise_remote_encrypt(struct noise_remote *r, uint32_t *r_idx, uint64_t *nonce, - uint8_t *buf, size_t buflen) +struct noise_keypair * +noise_encrypt_begin(struct noise_remote *r, uint64_t *nonce) { struct noise_keypair *kp; - int ret = EINVAL; rw_enter_read(&r->r_keypair_lock); if ((kp = r->r_current) == NULL) @@ -579,6 +576,21 @@ noise_remote_encrypt(struct noise_remote ((*nonce = noise_counter_send(&kp->kp_ctr)) > REJECT_AFTER_MESSAGES)) goto error; + return (kp); +error: + return (NULL); +} + +int +noise_remote_encrypt(struct noise_remote *r, uint32_t *r_idx, uint64_t *nonce, + uint8_t *buf, size_t buflen) +{ + struct noise_keypair *kp; + + kp = noise_encrypt_begin(r, nonce); + if (kp == NULL) + return (EINVAL); + /* We encrypt into the same buffer, so the caller must ensure that buf * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index * are passed back out to the caller through the provided data pointer. */ @@ -586,6 +598,15 @@ noise_remote_encrypt(struct noise_remote chacha20poly1305_encrypt(buf, buf, buflen, NULL, 0, *nonce, kp->kp_send); + return (noise_encrypt_commit(r, kp, *nonce)); +} + +int +noise_encrypt_commit(struct noise_remote *r, struct noise_keypair *kp, + uint64_t nonce) +{ + int ret; + /* If our values are still within tolerances, but we are approaching * the tolerances, we notify the caller with ESTALE that they should * establish a new keypair. 
The current keypair can continue to be used @@ -594,7 +615,7 @@ noise_remote_encrypt(struct noise_remote * - we're the initiator and our keypair is older than * REKEY_AFTER_TIME seconds */ ret = ESTALE; - if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) || + if ((kp->kp_valid && nonce >= REKEY_AFTER_MESSAGES) || (kp->kp_is_initiator && noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME, 0))) goto error; @@ -605,12 +626,10 @@ error: return ret; } -int -noise_remote_decrypt(struct noise_remote *r, uint32_t r_idx, uint64_t nonce, - uint8_t *buf, size_t buflen) +struct noise_keypair * +noise_decrypt_begin(struct noise_remote *r, uint32_t r_idx, uint64_t nonce) { struct noise_keypair *kp; - int ret = EINVAL; /* We retrieve the keypair corresponding to the provided index. We * attempt the current keypair first as that is most likely. We also @@ -636,12 +655,40 @@ noise_remote_decrypt(struct noise_remote kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES) goto error; + return (kp); + +error: + rw_exit(&r->r_keypair_lock); + return (NULL); +} + +int +noise_remote_decrypt(struct noise_remote *r, uint32_t r_idx, uint64_t nonce, + uint8_t *buf, size_t buflen) +{ + struct noise_keypair *kp; + + kp = noise_decrypt_begin(r, r_idx, nonce); + if (kp == NULL) + return (EINVAL); + /* Decrypt, then validate the counter. We don't want to validate the * counter before decrypting as we do not know the message is authentic * prior to decryption. 
*/ if (chacha20poly1305_decrypt(buf, buf, buflen, - NULL, 0, nonce, kp->kp_recv) == 0) - goto error; + NULL, 0, nonce, kp->kp_recv) == 0) { + noise_decrypt_rollback(r); + return (EINVAL); + } + + return (noise_decrypt_commit(r, kp, r_idx, nonce)); +} + +int +noise_decrypt_commit(struct noise_remote *r, struct noise_keypair *kp, + uint32_t r_idx, uint64_t nonce) +{ + int ret = EINVAL; if (noise_counter_recv(&kp->kp_ctr, nonce) != 0) goto error; @@ -684,6 +731,12 @@ error: return ret; } +void +noise_decrypt_rollback(struct noise_remote *r) +{ + rw_exit(&r->r_keypair_lock); +} + /* Private functions - these should not be called outside this file under any * circumstances. */ static struct noise_keypair * @@ -955,7 +1008,7 @@ noise_tai64n_now(uint8_t output[NOISE_TI memcpy(output + sizeof(sec), &nsec, sizeof(nsec)); } -static int +int noise_timer_expired(struct timespec *birthdate, time_t sec, long nsec) { struct timespec uptime; Index: wg_noise.h =================================================================== RCS file: /cvs/src/sys/net/wg_noise.h,v retrieving revision 1.2 diff -u -p -r1.2 wg_noise.h --- wg_noise.h 9 Dec 2020 05:53:33 -0000 1.2 +++ wg_noise.h 28 Apr 2021 01:00:19 -0000 @@ -176,18 +176,30 @@ void noise_remote_expire_current(struct int noise_remote_ready(struct noise_remote *); +struct noise_keypair * + noise_encrypt_begin(struct noise_remote *, uint64_t *); +int noise_encrypt_commit(struct noise_remote *, struct noise_keypair *, + uint64_t); int noise_remote_encrypt( struct noise_remote *, uint32_t *r_idx, uint64_t *nonce, uint8_t *buf, size_t buflen); +struct noise_keypair * + noise_decrypt_begin(struct noise_remote *, uint32_t, uint64_t); +int noise_decrypt_commit(struct noise_remote *, struct noise_keypair *, + uint32_t, uint64_t); +void noise_decrypt_rollback(struct noise_remote *); int noise_remote_decrypt( struct noise_remote *, uint32_t r_idx, uint64_t nonce, uint8_t *buf, size_t buflen); + +int noise_timer_expired(struct timespec *, time_t, 
long); + #ifdef WGTEST void noise_test();